1//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering from high level async operations to async.coro
10// and async.runtime operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
16#include "mlir/Dialect/Async/Passes.h"
17
18#include "PassDetail.h"
19#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/Async/IR/Async.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/ImplicitLocOpBuilder.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/RegionUtils.h"
30#include "llvm/Support/Debug.h"
31#include <optional>
32
33namespace mlir {
34#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
35#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
36#include "mlir/Dialect/Async/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40using namespace mlir::async;
41
42#define DEBUG_TYPE "async-to-async-runtime"
// Symbol name prefix given to functions outlined from `async.execute` op
// regions.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
45
46namespace {
47
48class AsyncToAsyncRuntimePass
49 : public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
50public:
51 AsyncToAsyncRuntimePass() = default;
52 void runOnOperation() override;
53};
54
55} // namespace
56
57namespace {
58
59class AsyncFuncToAsyncRuntimePass
60 : public impl::AsyncFuncToAsyncRuntimePassBase<
61 AsyncFuncToAsyncRuntimePass> {
62public:
63 AsyncFuncToAsyncRuntimePass() = default;
64 void runOnOperation() override;
65};
66
67} // namespace
68
69/// Function targeted for coroutine transformation has two additional blocks at
70/// the end: coroutine cleanup and coroutine suspension.
71///
72/// async.await op lowering additionaly creates a resume block for each
73/// operation to enable non-blocking waiting via coroutine suspension.
74namespace {
75struct CoroMachinery {
76 func::FuncOp func;
77
78 // Async function returns an optional token, followed by some async values
79 //
80 // async.func @foo() -> !async.value<T> {
81 // %cst = arith.constant 42.0 : T
82 // return %cst: T
83 // }
84 // Async execute region returns a completion token, and an async value for
85 // each yielded value.
86 //
87 // %token, %result = async.execute -> !async.value<T> {
88 // %0 = arith.constant ... : T
89 // async.yield %0 : T
90 // }
91 std::optional<Value> asyncToken; // returned completion token
92 llvm::SmallVector<Value, 4> returnValues; // returned async values
93
94 Value coroHandle; // coroutine handle (!async.coro.getHandle value)
95 Block *entry; // coroutine entry block
96 std::optional<Block *> setError; // set returned values to error state
97 Block *cleanup; // coroutine cleanup block
98
99 // Coroutine cleanup block for destroy after the coroutine is resumed,
100 // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101 //
102 // This cleanup block is a duplicate of the cleanup block followed by the
103 // resume block. The purpose of having a duplicate cleanup block for destroy
104 // is to make the CFG clear so that the control flow analysis won't confuse.
105 //
106 // The overall structure of the lowered CFG can be the following,
107 //
108 // Entry (calling async.coro.suspend)
109 // | \
110 // Resume Destroy (duplicate of Cleanup)
111 // | |
112 // Cleanup |
113 // | /
114 // End (ends the corontine)
115 //
116 // If there is resume-specific cleanup logic, it can go into the Cleanup
117 // block but not the destroy block. Otherwise, it can fail block dominance
118 // check.
119 Block *cleanupForDestroy;
120 Block *suspend; // coroutine suspension block
121};
122} // namespace
123
124using FuncCoroMapPtr =
125 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
126
127/// Utility to partially update the regular function CFG to the coroutine CFG
128/// compatible with LLVM coroutines switched-resume lowering using
129/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
130/// that branches into preexisting entry block. Also inserts trailing blocks.
131///
/// The result types of the passed `func` start with an optional `async.token`
/// followed by some number of `async.value`s.
134///
135/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
136///
137/// - `entry` block sets up the coroutine.
138/// - `set_error` block sets completion token and async values state to error.
139/// - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
141/// returned to the initial caller of a coroutine. Everything before the
142/// @llvm.coro.end() will be executed at every suspension point.
143///
144/// Coroutine structure (only the important bits):
145///
146/// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
147/// {
148/// ^entry(<function-arguments>):
149/// %token = <async token> : !async.token // create async runtime token
150/// %value = <async value> : !async.value<T> // create async value
151/// %id = async.coro.getId // create a coroutine id
152/// %hdl = async.coro.begin %id // create a coroutine handle
153/// cf.br ^preexisting_entry_block
154///
155/// /* preexisting blocks modified to branch to the cleanup block */
156///
157/// ^set_error: // this block created lazily only if needed (see code below)
158/// async.runtime.set_error %token : !async.token
159/// async.runtime.set_error %value : !async.value<T>
160/// cf.br ^cleanup
161///
162/// ^cleanup:
163/// async.coro.free %hdl // delete the coroutine state
164/// cf.br ^suspend
165///
166/// ^suspend:
167/// async.coro.end %hdl // marks the end of a coroutine
168/// return %token, %value : !async.token, !async.value<T>
169/// }
170///
171static CoroMachinery setupCoroMachinery(func::FuncOp func) {
172 assert(!func.getBlocks().empty() && "Function must have an entry block");
173
174 MLIRContext *ctx = func.getContext();
175 Block *entryBlock = &func.getBlocks().front();
176 Block *originalEntryBlock =
177 entryBlock->splitBlock(splitBefore: entryBlock->getOperations().begin());
178 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc: func->getLoc(), block: entryBlock);
179
180 // ------------------------------------------------------------------------ //
181 // Allocate async token/values that we will return from a ramp function.
182 // ------------------------------------------------------------------------ //
183
184 // We treat TokenType as state update marker to represent side-effects of
185 // async computations
186 bool isStateful = isa<TokenType>(Val: func.getResultTypes().front());
187
188 std::optional<Value> retToken;
189 if (isStateful)
190 retToken.emplace(args: builder.create<RuntimeCreateOp>(args: TokenType::get(ctx)));
191
192 llvm::SmallVector<Value, 4> retValues;
193 ArrayRef<Type> resValueTypes =
194 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195 for (auto resType : resValueTypes)
196 retValues.emplace_back(
197 Args: builder.create<RuntimeCreateOp>(args&: resType).getResult());
198
199 // ------------------------------------------------------------------------ //
200 // Initialize coroutine: get coroutine id and coroutine handle.
201 // ------------------------------------------------------------------------ //
202 auto coroIdOp = builder.create<CoroIdOp>(args: CoroIdType::get(ctx));
203 auto coroHdlOp =
204 builder.create<CoroBeginOp>(args: CoroHandleType::get(ctx), args: coroIdOp.getId());
205 builder.create<cf::BranchOp>(args&: originalEntryBlock);
206
207 Block *cleanupBlock = func.addBlock();
208 Block *cleanupBlockForDestroy = func.addBlock();
209 Block *suspendBlock = func.addBlock();
210
211 // ------------------------------------------------------------------------ //
212 // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
213 // ------------------------------------------------------------------------ //
214 auto buildCleanupBlock = [&](Block *cb) {
215 builder.setInsertionPointToStart(cb);
216 builder.create<CoroFreeOp>(args: coroIdOp.getId(), args: coroHdlOp.getHandle());
217
218 // Branch into the suspend block.
219 builder.create<cf::BranchOp>(args&: suspendBlock);
220 };
221 buildCleanupBlock(cleanupBlock);
222 buildCleanupBlock(cleanupBlockForDestroy);
223
224 // ------------------------------------------------------------------------ //
225 // Coroutine suspend block: mark the end of a coroutine and return allocated
226 // async token.
227 // ------------------------------------------------------------------------ //
228 builder.setInsertionPointToStart(suspendBlock);
229
230 // Mark the end of a coroutine: async.coro.end
231 builder.create<CoroEndOp>(args: coroHdlOp.getHandle());
232
233 // Return created optional `async.token` and `async.values` from the suspend
234 // block. This will be the return value of a coroutine ramp function.
235 SmallVector<Value, 4> ret;
236 if (retToken)
237 ret.push_back(Elt: *retToken);
238 llvm::append_range(C&: ret, R&: retValues);
239 builder.create<func::ReturnOp>(args&: ret);
240
241 // `async.await` op lowering will create resume blocks for async
242 // continuations, and will conditionally branch to cleanup or suspend blocks.
243
244 // The switch-resumed API based coroutine should be marked with
245 // presplitcoroutine attribute to mark the function as a coroutine.
246 func->setAttr(name: "passthrough", value: builder.getArrayAttr(
247 value: StringAttr::get(context: ctx, bytes: "presplitcoroutine")));
248
249 CoroMachinery machinery;
250 machinery.func = func;
251 machinery.asyncToken = retToken;
252 machinery.returnValues = retValues;
253 machinery.coroHandle = coroHdlOp.getHandle();
254 machinery.entry = entryBlock;
255 machinery.setError = std::nullopt; // created lazily only if needed
256 machinery.cleanup = cleanupBlock;
257 machinery.cleanupForDestroy = cleanupBlockForDestroy;
258 machinery.suspend = suspendBlock;
259 return machinery;
260}
261
262// Lazily creates `set_error` block only if it is required for lowering to the
263// runtime operations (see for example lowering of assert operation).
264static Block *setupSetErrorBlock(CoroMachinery &coro) {
265 if (coro.setError)
266 return *coro.setError;
267
268 coro.setError = coro.func.addBlock();
269 (*coro.setError)->moveBefore(block: coro.cleanup);
270
271 auto builder =
272 ImplicitLocOpBuilder::atBlockBegin(loc: coro.func->getLoc(), block: *coro.setError);
273
274 // Coroutine set_error block: set error on token and all returned values.
275 if (coro.asyncToken)
276 builder.create<RuntimeSetErrorOp>(args&: *coro.asyncToken);
277
278 for (Value retValue : coro.returnValues)
279 builder.create<RuntimeSetErrorOp>(args&: retValue);
280
281 // Branch into the cleanup block.
282 builder.create<cf::BranchOp>(args&: coro.cleanup);
283
284 return *coro.setError;
285}
286
287//===----------------------------------------------------------------------===//
288// async.execute op outlining to the coroutine functions.
289//===----------------------------------------------------------------------===//
290
291/// Outline the body region attached to the `async.execute` op into a standalone
292/// function.
293///
294/// Note that this is not reversible transformation.
295static std::pair<func::FuncOp, CoroMachinery>
296outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
297 ModuleOp module = execute->getParentOfType<ModuleOp>();
298
299 MLIRContext *ctx = module.getContext();
300 Location loc = execute.getLoc();
301
302 // Make sure that all constants will be inside the outlined async function to
303 // reduce the number of function arguments.
304 cloneConstantsIntoTheRegion(region&: execute.getBodyRegion());
305
306 // Collect all outlined function inputs.
307 SetVector<mlir::Value> functionInputs(llvm::from_range,
308 execute.getDependencies());
309 functionInputs.insert_range(R: execute.getBodyOperands());
310 getUsedValuesDefinedAbove(regions: execute.getBodyRegion(), values&: functionInputs);
311
312 // Collect types for the outlined function inputs and outputs.
313 auto typesRange = llvm::map_range(
314 C&: functionInputs, F: [](Value value) { return value.getType(); });
315 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
316 auto outputTypes = execute.getResultTypes();
317
318 auto funcType = FunctionType::get(context: ctx, inputs: inputTypes, results: outputTypes);
319 auto funcAttrs = ArrayRef<NamedAttribute>();
320
321 // TODO: Derive outlined function name from the parent FuncOp (support
322 // multiple nested async.execute operations).
323 func::FuncOp func =
324 func::FuncOp::create(location: loc, name: kAsyncFnPrefix, type: funcType, attrs: funcAttrs);
325 symbolTable.insert(symbol: func);
326
327 SymbolTable::setSymbolVisibility(symbol: func, vis: SymbolTable::Visibility::Private);
328 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, block: func.addEntryBlock());
329
330 // Prepare for coroutine conversion by creating the body of the function.
331 {
332 size_t numDependencies = execute.getDependencies().size();
333 size_t numOperands = execute.getBodyOperands().size();
334
335 // Await on all dependencies before starting to execute the body region.
336 for (size_t i = 0; i < numDependencies; ++i)
337 builder.create<AwaitOp>(args: func.getArgument(idx: i));
338
339 // Await on all async value operands and unwrap the payload.
340 SmallVector<Value, 4> unwrappedOperands(numOperands);
341 for (size_t i = 0; i < numOperands; ++i) {
342 Value operand = func.getArgument(idx: numDependencies + i);
343 unwrappedOperands[i] = builder.create<AwaitOp>(location: loc, args&: operand).getResult();
344 }
345
346 // Map from function inputs defined above the execute op to the function
347 // arguments.
348 IRMapping valueMapping;
349 valueMapping.map(from&: functionInputs, to: func.getArguments());
350 valueMapping.map(from: execute.getBodyRegion().getArguments(), to&: unwrappedOperands);
351
352 // Clone all operations from the execute operation body into the outlined
353 // function body.
354 for (Operation &op : execute.getBodyRegion().getOps())
355 builder.clone(op, mapper&: valueMapping);
356 }
357
358 // Adding entry/cleanup/suspend blocks.
359 CoroMachinery coro = setupCoroMachinery(func);
360
361 // Suspend async function at the end of an entry block, and resume it using
362 // Async resume operation (execution will be resumed in a thread managed by
363 // the async runtime).
364 {
365 cf::BranchOp branch = cast<cf::BranchOp>(Val: coro.entry->getTerminator());
366 builder.setInsertionPointToEnd(coro.entry);
367
368 // Save the coroutine state: async.coro.save
369 auto coroSaveOp =
370 builder.create<CoroSaveOp>(args: CoroStateType::get(ctx), args&: coro.coroHandle);
371
372 // Pass coroutine to the runtime to be resumed on a runtime managed
373 // thread.
374 builder.create<RuntimeResumeOp>(args&: coro.coroHandle);
375
376 // Add async.coro.suspend as a suspended block terminator.
377 builder.create<CoroSuspendOp>(args: coroSaveOp.getState(), args&: coro.suspend,
378 args: branch.getDest(), args&: coro.cleanupForDestroy);
379
380 branch.erase();
381 }
382
383 // Replace the original `async.execute` with a call to outlined function.
384 {
385 ImplicitLocOpBuilder callBuilder(loc, execute);
386 auto callOutlinedFunc = callBuilder.create<func::CallOp>(
387 args: func.getName(), args: execute.getResultTypes(), args: functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(values: callOutlinedFunc.getResults());
389 execute.erase();
390 }
391
392 return {func, coro};
393}
394
395//===----------------------------------------------------------------------===//
396// Convert async.create_group operation to async.runtime.create_group
397//===----------------------------------------------------------------------===//
398
399namespace {
400class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
401public:
402 using OpConversionPattern::OpConversionPattern;
403
404 LogicalResult
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override {
407 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
408 op, args: GroupType::get(ctx: op->getContext()), args: adaptor.getOperands());
409 return success();
410 }
411};
412} // namespace
413
414//===----------------------------------------------------------------------===//
415// Convert async.add_to_group operation to async.runtime.add_to_group.
416//===----------------------------------------------------------------------===//
417
418namespace {
419class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
420public:
421 using OpConversionPattern::OpConversionPattern;
422
423 LogicalResult
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
427 op, args: rewriter.getIndexType(), args: adaptor.getOperands());
428 return success();
429 }
430};
431} // namespace
432
433//===----------------------------------------------------------------------===//
434// Convert async.func, async.return and async.call operations to non-blocking
435// operations based on llvm coroutine
436//===----------------------------------------------------------------------===//
437
438namespace {
439
440//===----------------------------------------------------------------------===//
441// Convert async.func operation to func.func
442//===----------------------------------------------------------------------===//
443
444class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
445public:
446 AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
447 : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
448
449 LogicalResult
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452 Location loc = op->getLoc();
453
454 auto newFuncOp =
455 rewriter.create<func::FuncOp>(location: loc, args: op.getName(), args: op.getFunctionType());
456
457 SymbolTable::setSymbolVisibility(symbol: newFuncOp,
458 vis: SymbolTable::getSymbolVisibility(symbol: op));
459 // Copy over all attributes other than the name.
460 for (const auto &namedAttr : op->getAttrs()) {
461 if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
462 newFuncOp->setAttr(name: namedAttr.getName(), value: namedAttr.getValue());
463 }
464
465 rewriter.inlineRegionBefore(region&: op.getBody(), parent&: newFuncOp.getBody(),
466 before: newFuncOp.end());
467
468 CoroMachinery coro = setupCoroMachinery(newFuncOp);
469 (*coros)[newFuncOp] = coro;
470 // no initial suspend, we should hot-start
471
472 rewriter.eraseOp(op);
473 return success();
474 }
475
476private:
477 FuncCoroMapPtr coros;
478};
479
480//===----------------------------------------------------------------------===//
481// Convert async.call operation to func.call
482//===----------------------------------------------------------------------===//
483
484class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
485public:
486 AsyncCallOpLowering(MLIRContext *ctx)
487 : OpConversionPattern<async::CallOp>(ctx) {}
488
489 LogicalResult
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter) const override {
492 rewriter.replaceOpWithNewOp<func::CallOp>(
493 op, args: op.getCallee(), args: op.getResultTypes(), args: op.getOperands());
494 return success();
495 }
496};
497
498//===----------------------------------------------------------------------===//
499// Convert async.return operation to async.runtime operations.
500//===----------------------------------------------------------------------===//
501
502class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
503public:
504 AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
505 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
506
507 LogicalResult
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(Val: func);
512 if (funcCoro == coros->end())
513 return rewriter.notifyMatchFailure(
514 arg&: op, msg: "operation is not inside the async coroutine function");
515
516 Location loc = op->getLoc();
517 const CoroMachinery &coro = funcCoro->getSecond();
518 rewriter.setInsertionPointAfter(op);
519
520 // Store return values into the async values storage and switch async
521 // values state to available.
522 for (auto tuple : llvm::zip(t: adaptor.getOperands(), u: coro.returnValues)) {
523 Value returnValue = std::get<0>(t&: tuple);
524 Value asyncValue = std::get<1>(t&: tuple);
525 rewriter.create<RuntimeStoreOp>(location: loc, args&: returnValue, args&: asyncValue);
526 rewriter.create<RuntimeSetAvailableOp>(location: loc, args&: asyncValue);
527 }
528
529 if (coro.asyncToken)
530 // Switch the coroutine completion token to available state.
531 rewriter.create<RuntimeSetAvailableOp>(location: loc, args: *coro.asyncToken);
532
533 rewriter.eraseOp(op);
534 rewriter.create<cf::BranchOp>(location: loc, args: coro.cleanup);
535 return success();
536 }
537
538private:
539 FuncCoroMapPtr coros;
540};
541} // namespace
542
543//===----------------------------------------------------------------------===//
544// Convert async.await and async.await_all operations to the async.runtime.await
545// or async.runtime.await_and_resume operations.
546//===----------------------------------------------------------------------===//
547
548namespace {
549template <typename AwaitType, typename AwaitableType>
550class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
551 using AwaitAdaptor = typename AwaitType::Adaptor;
552
553public:
554 AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
555 bool shouldLowerBlockingWait)
556 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
558
559 LogicalResult
560 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
561 ConversionPatternRewriter &rewriter) const override {
562 // We can only await on one the `AwaitableType` (for `await` it can be
563 // a `token` or a `value`, for `await_all` it must be a `group`).
564 if (!isa<AwaitableType>(op.getOperand().getType()))
565 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
566
567 // Check if await operation is inside the coroutine function.
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
571
572 Location loc = op->getLoc();
573 Value operand = adaptor.getOperand();
574
575 Type i1 = rewriter.getI1Type();
576
577 // Delay lowering to block wait in case await op is inside async.execute
578 if (!isInCoroutine && !shouldLowerBlockingWait)
579 return failure();
580
581 // Inside regular functions we use the blocking wait operation to wait for
582 // the async object (token, value or group) to become available.
583 if (!isInCoroutine) {
584 ImplicitLocOpBuilder builder(loc, rewriter);
585 builder.create<RuntimeAwaitOp>(location: loc, args&: operand);
586
587 // Assert that the awaited operands is not in the error state.
588 Value isError = builder.create<RuntimeIsErrorOp>(args&: i1, args&: operand);
589 Value notError = builder.create<arith::XOrIOp>(
590 args&: isError, args: builder.create<arith::ConstantOp>(
591 location: loc, args&: i1, args: builder.getIntegerAttr(type: i1, value: 1)));
592
593 builder.create<cf::AssertOp>(args&: notError,
594 args: "Awaited async operand is in error state");
595 }
596
597 // Inside the coroutine we convert await operation into coroutine suspension
598 // point, and resume execution asynchronously.
599 if (isInCoroutine) {
600 CoroMachinery &coro = funcCoro->getSecond();
601 Block *suspended = op->getBlock();
602
603 ImplicitLocOpBuilder builder(loc, rewriter);
604 MLIRContext *ctx = op->getContext();
605
606 // Save the coroutine state and resume on a runtime managed thread when
607 // the operand becomes available.
608 auto coroSaveOp =
609 builder.create<CoroSaveOp>(args: CoroStateType::get(ctx), args&: coro.coroHandle);
610 builder.create<RuntimeAwaitAndResumeOp>(args&: operand, args&: coro.coroHandle);
611
612 // Split the entry block before the await operation.
613 Block *resume = rewriter.splitBlock(block: suspended, before: Block::iterator(op));
614
615 // Add async.coro.suspend as a suspended block terminator.
616 builder.setInsertionPointToEnd(suspended);
617 builder.create<CoroSuspendOp>(args: coroSaveOp.getState(), args&: coro.suspend, args&: resume,
618 args&: coro.cleanupForDestroy);
619
620 // Split the resume block into error checking and continuation.
621 Block *continuation = rewriter.splitBlock(block: resume, before: Block::iterator(op));
622
623 // Check if the awaited value is in the error state.
624 builder.setInsertionPointToStart(resume);
625 auto isError = builder.create<RuntimeIsErrorOp>(location: loc, args&: i1, args&: operand);
626 builder.create<cf::CondBranchOp>(args&: isError,
627 /*trueDest=*/args: setupSetErrorBlock(coro),
628 /*trueArgs=*/args: ArrayRef<Value>(),
629 /*falseDest=*/args&: continuation,
630 /*falseArgs=*/args: ArrayRef<Value>());
631
632 // Make sure that replacement value will be constructed in the
633 // continuation block.
634 rewriter.setInsertionPointToStart(continuation);
635 }
636
637 // Erase or replace the await operation with the new value.
638 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
639 rewriter.replaceOp(op, replaceWith);
640 else
641 rewriter.eraseOp(op);
642
643 return success();
644 }
645
646 virtual Value getReplacementValue(AwaitType op, Value operand,
647 ConversionPatternRewriter &rewriter) const {
648 return Value();
649 }
650
651private:
652 FuncCoroMapPtr coros;
653 bool shouldLowerBlockingWait;
654};
655
656/// Lowering for `async.await` with a token operand.
657class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
658 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
659
660public:
661 using Base::Base;
662};
663
664/// Lowering for `async.await` with a value operand.
665class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
666 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
667
668public:
669 using Base::Base;
670
671 Value
672 getReplacementValue(AwaitOp op, Value operand,
673 ConversionPatternRewriter &rewriter) const override {
674 // Load from the async value storage.
675 auto valueType = cast<ValueType>(Val: operand.getType()).getValueType();
676 return rewriter.create<RuntimeLoadOp>(location: op->getLoc(), args&: valueType, args&: operand);
677 }
678};
679
680/// Lowering for `async.await_all` operation.
681class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
682 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
683
684public:
685 using Base::Base;
686};
687
688} // namespace
689
690//===----------------------------------------------------------------------===//
691// Convert async.yield operation to async.runtime operations.
692//===----------------------------------------------------------------------===//
693
694class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
695public:
696 YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
697 : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
698
699 LogicalResult
700 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
701 ConversionPatternRewriter &rewriter) const override {
702 // Check if yield operation is inside the async coroutine function.
703 auto func = op->template getParentOfType<func::FuncOp>();
704 auto funcCoro = coros->find(Val: func);
705 if (funcCoro == coros->end())
706 return rewriter.notifyMatchFailure(
707 arg&: op, msg: "operation is not inside the async coroutine function");
708
709 Location loc = op->getLoc();
710 const CoroMachinery &coro = funcCoro->getSecond();
711
712 // Store yielded values into the async values storage and switch async
713 // values state to available.
714 for (auto tuple : llvm::zip(t: adaptor.getOperands(), u: coro.returnValues)) {
715 Value yieldValue = std::get<0>(t&: tuple);
716 Value asyncValue = std::get<1>(t&: tuple);
717 rewriter.create<RuntimeStoreOp>(location: loc, args&: yieldValue, args&: asyncValue);
718 rewriter.create<RuntimeSetAvailableOp>(location: loc, args&: asyncValue);
719 }
720
721 if (coro.asyncToken)
722 // Switch the coroutine completion token to available state.
723 rewriter.create<RuntimeSetAvailableOp>(location: loc, args: *coro.asyncToken);
724
725 rewriter.create<cf::BranchOp>(location: loc, args: coro.cleanup);
726 rewriter.eraseOp(op);
727
728 return success();
729 }
730
731private:
732 FuncCoroMapPtr coros;
733};
734
735//===----------------------------------------------------------------------===//
736// Convert cf.assert operation to cf.cond_br into `set_error` block.
737//===----------------------------------------------------------------------===//
738
739class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
740public:
741 AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
742 : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
743
744 LogicalResult
745 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
746 ConversionPatternRewriter &rewriter) const override {
747 // Check if assert operation is inside the async coroutine function.
748 auto func = op->template getParentOfType<func::FuncOp>();
749 auto funcCoro = coros->find(Val: func);
750 if (funcCoro == coros->end())
751 return rewriter.notifyMatchFailure(
752 arg&: op, msg: "operation is not inside the async coroutine function");
753
754 Location loc = op->getLoc();
755 CoroMachinery &coro = funcCoro->getSecond();
756
757 Block *cont = rewriter.splitBlock(block: op->getBlock(), before: Block::iterator(op));
758 rewriter.setInsertionPointToEnd(cont->getPrevNode());
759 rewriter.create<cf::CondBranchOp>(location: loc, args: adaptor.getArg(),
760 /*trueDest=*/args&: cont,
761 /*trueArgs=*/args: ArrayRef<Value>(),
762 /*falseDest=*/args: setupSetErrorBlock(coro),
763 /*falseArgs=*/args: ArrayRef<Value>());
764 rewriter.eraseOp(op);
765
766 return success();
767 }
768
769private:
770 FuncCoroMapPtr coros;
771};
772
773//===----------------------------------------------------------------------===//
774void AsyncToAsyncRuntimePass::runOnOperation() {
775 ModuleOp module = getOperation();
776 SymbolTable symbolTable(module);
777
778 // Functions with coroutine CFG setups, which are results of outlining
779 // `async.execute` body regions
780 FuncCoroMapPtr coros =
781 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
782
783 module.walk(callback: [&](ExecuteOp execute) {
784 coros->insert(KV: outlineExecuteOp(symbolTable, execute));
785 });
786
787 LLVM_DEBUG({
788 llvm::dbgs() << "Outlined " << coros->size()
789 << " functions built from async.execute operations\n";
790 });
791
792 // Returns true if operation is inside the coroutine.
793 auto isInCoroutine = [&](Operation *op) -> bool {
794 auto parentFunc = op->getParentOfType<func::FuncOp>();
795 return coros->contains(Val: parentFunc);
796 };
797
798 // Lower async operations to async.runtime operations.
799 MLIRContext *ctx = module->getContext();
800 RewritePatternSet asyncPatterns(ctx);
801
802 // Conversion to async runtime augments original CFG with the coroutine CFG,
803 // and we have to make sure that structured control flow operations with async
804 // operations in nested regions will be converted to branch-based control flow
805 // before we add the coroutine basic blocks.
806 populateSCFToControlFlowConversionPatterns(patterns&: asyncPatterns);
807
808 // Async lowering does not use type converter because it must preserve all
809 // types for async.runtime operations.
810 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(arg&: ctx);
811
812 asyncPatterns
813 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
814 arg&: ctx, args&: coros, /*should_lower_blocking_wait=*/args: true);
815
816 // Lower assertions to conditional branches into error blocks.
817 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(arg&: ctx, args&: coros);
818
819 // All high level async operations must be lowered to the runtime operations.
820 ConversionTarget runtimeTarget(*ctx);
821 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
822 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
823 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
824
825 // Decide if structured control flow has to be lowered to branch-based CFG.
826 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>(callback: [&](Operation *op) {
827 auto walkResult = op->walk(callback: [&](Operation *nested) {
828 bool isAsync = isa<async::AsyncDialect>(Val: nested->getDialect());
829 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
830 : WalkResult::advance();
831 });
832 return !walkResult.wasInterrupted();
833 });
834 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
835 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
836
837 // Assertions must be converted to runtime errors inside async functions.
838 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
839 callback: [&](cf::AssertOp op) -> bool {
840 auto func = op->getParentOfType<func::FuncOp>();
841 return !coros->contains(Val: func);
842 });
843
844 if (failed(Result: applyPartialConversion(op: module, target: runtimeTarget,
845 patterns: std::move(asyncPatterns)))) {
846 signalPassFailure();
847 return;
848 }
849}
850
851//===----------------------------------------------------------------------===//
852void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
853 RewritePatternSet &patterns, ConversionTarget &target) {
854 // Functions with coroutine CFG setups, which are results of converting
855 // async.func.
856 FuncCoroMapPtr coros =
857 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
858 MLIRContext *ctx = patterns.getContext();
859 // Lower async.func to func.func with coroutine cfg.
860 patterns.add<AsyncCallOpLowering>(arg&: ctx);
861 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(arg&: ctx, args&: coros);
862
863 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
864 arg&: ctx, args&: coros, /*should_lower_blocking_wait=*/args: false);
865 patterns.add<YieldOpLowering, AssertOpLowering>(arg&: ctx, args&: coros);
866
867 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
868 callback: [coros](Operation *op) {
869 auto exec = op->getParentOfType<ExecuteOp>();
870 auto func = op->getParentOfType<func::FuncOp>();
871 return exec || !coros->contains(Val: func);
872 });
873}
874
875void AsyncFuncToAsyncRuntimePass::runOnOperation() {
876 ModuleOp module = getOperation();
877
878 // Lower async operations to async.runtime operations.
879 MLIRContext *ctx = module->getContext();
880 RewritePatternSet asyncPatterns(ctx);
881 ConversionTarget runtimeTarget(*ctx);
882
883 // Lower async.func to func.func with coroutine cfg.
884 populateAsyncFuncToAsyncRuntimeConversionPatterns(patterns&: asyncPatterns,
885 target&: runtimeTarget);
886
887 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
888 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
889
890 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
891 cf::BranchOp, cf::CondBranchOp>();
892
893 if (failed(Result: applyPartialConversion(op: module, target: runtimeTarget,
894 patterns: std::move(asyncPatterns)))) {
895 signalPassFailure();
896 return;
897 }
898}
899

// Source: mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp