1//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13#include "mlir/Conversion/LLVMCommon/Pattern.h"
14#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Async/IR/Async.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
19#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/IR/ImplicitLocOpBuilder.h"
22#include "mlir/IR/TypeUtilities.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32#define DEBUG_TYPE "convert-async-to-llvm"
33
34using namespace mlir;
35using namespace mlir::async;
36
37//===----------------------------------------------------------------------===//
38// Async Runtime C API declaration.
39//===----------------------------------------------------------------------===//
40
// Symbol names of the Async Runtime C API functions referenced by this file.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Access to the storage of an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Group membership.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Completion: emplace (success) or set error.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Await entry points.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execute and await-then-execute entry points.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Worker thread count query.
// NOTE(review): the literal reads "Runtim" (no 'e'). Presumably it mirrors the
// symbol exported by the runtime library - confirm before "fixing" it here.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
69
70namespace {
71/// Async Runtime API function types.
72///
73/// Because we can't create API function signature for type parametrized
74/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
75/// lowering all async data types become opaque pointers at runtime.
76struct AsyncAPI {
77 // All async types are lowered to opaque LLVM pointers at runtime.
78 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
79 return LLVM::LLVMPointerType::get(context: ctx);
80 }
81
82 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
83 return LLVM::LLVMTokenType::get(ctx);
84 }
85
86 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
87 auto ref = opaquePointerType(ctx);
88 auto count = IntegerType::get(context: ctx, width: 64);
89 return FunctionType::get(context: ctx, inputs: {ref, count}, results: {});
90 }
91
92 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
93 return FunctionType::get(context: ctx, inputs: {}, results: {TokenType::get(ctx)});
94 }
95
96 static FunctionType createValueFunctionType(MLIRContext *ctx) {
97 auto i64 = IntegerType::get(context: ctx, width: 64);
98 auto value = opaquePointerType(ctx);
99 return FunctionType::get(context: ctx, inputs: {i64}, results: {value});
100 }
101
102 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
103 auto i64 = IntegerType::get(context: ctx, width: 64);
104 return FunctionType::get(context: ctx, inputs: {i64}, results: {GroupType::get(ctx)});
105 }
106
107 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
108 auto ptrType = opaquePointerType(ctx);
109 return FunctionType::get(context: ctx, inputs: {ptrType}, results: {ptrType});
110 }
111
112 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
113 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx)}, results: {});
114 }
115
116 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
117 auto value = opaquePointerType(ctx);
118 return FunctionType::get(context: ctx, inputs: {value}, results: {});
119 }
120
121 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
122 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx)}, results: {});
123 }
124
125 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
126 auto value = opaquePointerType(ctx);
127 return FunctionType::get(context: ctx, inputs: {value}, results: {});
128 }
129
130 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
131 auto i1 = IntegerType::get(context: ctx, width: 1);
132 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx)}, results: {i1});
133 }
134
135 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
136 auto value = opaquePointerType(ctx);
137 auto i1 = IntegerType::get(context: ctx, width: 1);
138 return FunctionType::get(context: ctx, inputs: {value}, results: {i1});
139 }
140
141 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
142 auto i1 = IntegerType::get(context: ctx, width: 1);
143 return FunctionType::get(context: ctx, inputs: {GroupType::get(ctx)}, results: {i1});
144 }
145
146 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
147 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx)}, results: {});
148 }
149
150 static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
151 auto value = opaquePointerType(ctx);
152 return FunctionType::get(context: ctx, inputs: {value}, results: {});
153 }
154
155 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
156 return FunctionType::get(context: ctx, inputs: {GroupType::get(ctx)}, results: {});
157 }
158
159 static FunctionType executeFunctionType(MLIRContext *ctx) {
160 auto ptrType = opaquePointerType(ctx);
161 return FunctionType::get(context: ctx, inputs: {ptrType, ptrType}, results: {});
162 }
163
164 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
165 auto i64 = IntegerType::get(context: ctx, width: 64);
166 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx), GroupType::get(ctx)},
167 results: {i64});
168 }
169
170 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
171 auto ptrType = opaquePointerType(ctx);
172 return FunctionType::get(context: ctx, inputs: {TokenType::get(ctx), ptrType, ptrType}, results: {});
173 }
174
175 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
176 auto ptrType = opaquePointerType(ctx);
177 return FunctionType::get(context: ctx, inputs: {ptrType, ptrType, ptrType}, results: {});
178 }
179
180 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
181 auto ptrType = opaquePointerType(ctx);
182 return FunctionType::get(context: ctx, inputs: {GroupType::get(ctx), ptrType, ptrType}, results: {});
183 }
184
185 static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
186 return FunctionType::get(context: ctx, inputs: {}, results: {IndexType::get(context: ctx)});
187 }
188
189 // Auxiliary coroutine resume intrinsic wrapper.
190 static Type resumeFunctionType(MLIRContext *ctx) {
191 auto voidTy = LLVM::LLVMVoidType::get(ctx);
192 auto ptrType = opaquePointerType(ctx);
193 return LLVM::LLVMFunctionType::get(result: voidTy, arguments: {ptrType}, isVarArg: false);
194 }
195};
196} // namespace
197
198/// Adds Async Runtime C API declarations to the module.
199static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
200 auto builder =
201 ImplicitLocOpBuilder::atBlockEnd(loc: module.getLoc(), block: module.getBody());
202
203 auto addFuncDecl = [&](StringRef name, FunctionType type) {
204 if (module.lookupSymbol(name))
205 return;
206 builder.create<func::FuncOp>(args&: name, args&: type).setPrivate();
207 };
208
209 MLIRContext *ctx = module.getContext();
210 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
211 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
213 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
214 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
215 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
216 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
217 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
218 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
219 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
220 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
221 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
222 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
223 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
224 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
225 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
226 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
227 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
228 addFuncDecl(kAwaitTokenAndExecute,
229 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
230 addFuncDecl(kAwaitValueAndExecute,
231 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
232 addFuncDecl(kAwaitAllAndExecute,
233 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
234 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
235}
236
237//===----------------------------------------------------------------------===//
238// Coroutine resume function wrapper.
239//===----------------------------------------------------------------------===//
240
241static constexpr const char *kResume = "__resume";
242
243/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
244/// intrinsics. We need this function to be able to pass it to the async
245/// runtime execute API.
246static void addResumeFunction(ModuleOp module) {
247 if (module.lookupSymbol(name: kResume))
248 return;
249
250 MLIRContext *ctx = module.getContext();
251 auto loc = module.getLoc();
252 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block: module.getBody());
253
254 auto voidTy = LLVM::LLVMVoidType::get(ctx);
255 Type ptrType = AsyncAPI::opaquePointerType(ctx);
256
257 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
258 args: kResume, args: LLVM::LLVMFunctionType::get(result: voidTy, arguments: {ptrType}));
259 resumeOp.setPrivate();
260
261 auto *block = resumeOp.addEntryBlock(builder&: moduleBuilder);
262 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
263
264 blockBuilder.create<LLVM::CoroResumeOp>(args: resumeOp.getArgument(idx: 0));
265 blockBuilder.create<LLVM::ReturnOp>(args: ValueRange());
266}
267
268//===----------------------------------------------------------------------===//
269// Convert Async dialect types to LLVM types.
270//===----------------------------------------------------------------------===//
271
272namespace {
273/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
274/// their runtime type (opaque pointers) and does not convert any other types.
275class AsyncRuntimeTypeConverter : public TypeConverter {
276public:
277 AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
278 addConversion(callback: [](Type type) { return type; });
279 addConversion(callback: [](Type type) { return convertAsyncTypes(type); });
280
281 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
282 // in patterns for other dialects.
283 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
284 ValueRange inputs, Location loc) -> Value {
285 auto cast = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
286 return cast.getResult(i: 0);
287 };
288
289 addSourceMaterialization(callback&: addUnrealizedCast);
290 addTargetMaterialization(callback&: addUnrealizedCast);
291 }
292
293 static std::optional<Type> convertAsyncTypes(Type type) {
294 if (isa<TokenType, GroupType, ValueType>(Val: type))
295 return AsyncAPI::opaquePointerType(ctx: type.getContext());
296
297 if (isa<CoroIdType, CoroStateType>(Val: type))
298 return AsyncAPI::tokenType(ctx: type.getContext());
299 if (isa<CoroHandleType>(Val: type))
300 return AsyncAPI::opaquePointerType(ctx: type.getContext());
301
302 return std::nullopt;
303 }
304};
305
306/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
307/// as type converter. Allows access to it via the 'getTypeConverter'
308/// convenience method.
309template <typename SourceOp>
310class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
311
312 using Base = OpConversionPattern<SourceOp>;
313
314public:
315 AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
316 MLIRContext *context)
317 : Base(typeConverter, context) {}
318
319 /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
320 const AsyncRuntimeTypeConverter *getTypeConverter() const {
321 return static_cast<const AsyncRuntimeTypeConverter *>(
322 Base::getTypeConverter());
323 }
324};
325
326} // namespace
327
328//===----------------------------------------------------------------------===//
329// Convert async.coro.id to @llvm.coro.id intrinsic.
330//===----------------------------------------------------------------------===//
331
332namespace {
333class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
334public:
335 using AsyncOpConversionPattern::AsyncOpConversionPattern;
336
337 LogicalResult
338 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
340 auto token = AsyncAPI::tokenType(ctx: op->getContext());
341 auto ptrType = AsyncAPI::opaquePointerType(ctx: op->getContext());
342 auto loc = op->getLoc();
343
344 // Constants for initializing coroutine frame.
345 auto constZero =
346 rewriter.create<LLVM::ConstantOp>(location: loc, args: rewriter.getI32Type(), args: 0);
347 auto nullPtr = rewriter.create<LLVM::ZeroOp>(location: loc, args&: ptrType);
348
349 // Get coroutine id: @llvm.coro.id.
350 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
351 op, args&: token, args: ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
352
353 return success();
354 }
355};
356} // namespace
357
358//===----------------------------------------------------------------------===//
359// Convert async.coro.begin to @llvm.coro.begin intrinsic.
360//===----------------------------------------------------------------------===//
361
362namespace {
363class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
364public:
365 using AsyncOpConversionPattern::AsyncOpConversionPattern;
366
367 LogicalResult
368 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter) const override {
370 auto ptrType = AsyncAPI::opaquePointerType(ctx: op->getContext());
371 auto loc = op->getLoc();
372
373 // Get coroutine frame size: @llvm.coro.size.i64.
374 Value coroSize =
375 rewriter.create<LLVM::CoroSizeOp>(location: loc, args: rewriter.getI64Type());
376 // Get coroutine frame alignment: @llvm.coro.align.i64.
377 Value coroAlign =
378 rewriter.create<LLVM::CoroAlignOp>(location: loc, args: rewriter.getI64Type());
379
380 // Round up the size to be multiple of the alignment. Since aligned_alloc
381 // requires the size parameter be an integral multiple of the alignment
382 // parameter.
383 auto makeConstant = [&](uint64_t c) {
384 return rewriter.create<LLVM::ConstantOp>(location: op->getLoc(),
385 args: rewriter.getI64Type(), args&: c);
386 };
387 coroSize = rewriter.create<LLVM::AddOp>(location: op->getLoc(), args&: coroSize, args&: coroAlign);
388 coroSize =
389 rewriter.create<LLVM::SubOp>(location: op->getLoc(), args&: coroSize, args: makeConstant(1));
390 Value negCoroAlign =
391 rewriter.create<LLVM::SubOp>(location: op->getLoc(), args: makeConstant(0), args&: coroAlign);
392 coroSize =
393 rewriter.create<LLVM::AndOp>(location: op->getLoc(), args&: coroSize, args&: negCoroAlign);
394
395 // Allocate memory for the coroutine frame.
396 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
397 b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>(), indexType: rewriter.getI64Type());
398 if (failed(Result: allocFuncOp))
399 return failure();
400 auto coroAlloc = rewriter.create<LLVM::CallOp>(
401 location: loc, args&: allocFuncOp.value(), args: ValueRange{coroAlign, coroSize});
402
403 // Begin a coroutine: @llvm.coro.begin.
404 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
405 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
406 op, args&: ptrType, args: ValueRange({coroId, coroAlloc.getResult()}));
407
408 return success();
409 }
410};
411} // namespace
412
413//===----------------------------------------------------------------------===//
414// Convert async.coro.free to @llvm.coro.free intrinsic.
415//===----------------------------------------------------------------------===//
416
417namespace {
418class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
419public:
420 using AsyncOpConversionPattern::AsyncOpConversionPattern;
421
422 LogicalResult
423 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 auto ptrType = AsyncAPI::opaquePointerType(ctx: op->getContext());
426 auto loc = op->getLoc();
427
428 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
429 auto coroMem =
430 rewriter.create<LLVM::CoroFreeOp>(location: loc, args&: ptrType, args: adaptor.getOperands());
431
432 // Free the memory.
433 auto freeFuncOp =
434 LLVM::lookupOrCreateFreeFn(b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>());
435 if (failed(Result: freeFuncOp))
436 return failure();
437 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, args&: freeFuncOp.value(),
438 args: ValueRange(coroMem.getResult()));
439
440 return success();
441 }
442};
443} // namespace
444
445//===----------------------------------------------------------------------===//
446// Convert async.coro.end to @llvm.coro.end intrinsic.
447//===----------------------------------------------------------------------===//
448
449namespace {
450class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
451public:
452 using OpConversionPattern::OpConversionPattern;
453
454 LogicalResult
455 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
456 ConversionPatternRewriter &rewriter) const override {
457 // We are not in the block that is part of the unwind sequence.
458 auto constFalse = rewriter.create<LLVM::ConstantOp>(
459 location: op->getLoc(), args: rewriter.getI1Type(), args: rewriter.getBoolAttr(value: false));
460 auto noneToken = rewriter.create<LLVM::NoneTokenOp>(location: op->getLoc());
461
462 // Mark the end of a coroutine: @llvm.coro.end.
463 auto coroHdl = adaptor.getHandle();
464 rewriter.create<LLVM::CoroEndOp>(
465 location: op->getLoc(), args: rewriter.getI1Type(),
466 args: ValueRange({coroHdl, constFalse, noneToken}));
467 rewriter.eraseOp(op);
468
469 return success();
470 }
471};
472} // namespace
473
474//===----------------------------------------------------------------------===//
475// Convert async.coro.save to @llvm.coro.save intrinsic.
476//===----------------------------------------------------------------------===//
477
478namespace {
479class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
480public:
481 using OpConversionPattern::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486 // Save the coroutine state: @llvm.coro.save
487 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
488 op, args: AsyncAPI::tokenType(ctx: op->getContext()), args: adaptor.getOperands());
489
490 return success();
491 }
492};
493} // namespace
494
495//===----------------------------------------------------------------------===//
496// Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
497//===----------------------------------------------------------------------===//
498
499namespace {
500
501/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
502/// branch to the appropriate block based on the return code.
503///
504/// Before:
505///
506/// ^suspended:
507/// "opBefore"(...)
508/// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
509/// ^resume:
510/// "op"(...)
511/// ^cleanup: ...
512/// ^suspend: ...
513///
514/// After:
515///
516/// ^suspended:
517/// "opBefore"(...)
518/// %suspend = llmv.intr.coro.suspend ...
519/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
520/// ^resume:
521/// "op"(...)
522/// ^cleanup: ...
523/// ^suspend: ...
524///
525class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
526public:
527 using OpConversionPattern::OpConversionPattern;
528
529 LogicalResult
530 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
531 ConversionPatternRewriter &rewriter) const override {
532 auto i8 = rewriter.getIntegerType(width: 8);
533 auto i32 = rewriter.getI32Type();
534 auto loc = op->getLoc();
535
536 // This is not a final suspension point.
537 auto constFalse = rewriter.create<LLVM::ConstantOp>(
538 location: loc, args: rewriter.getI1Type(), args: rewriter.getBoolAttr(value: false));
539
540 // Suspend a coroutine: @llvm.coro.suspend
541 auto coroState = adaptor.getState();
542 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
543 location: loc, args&: i8, args: ValueRange({coroState, constFalse}));
544
545 // Cast return code to i32.
546
547 // After a suspension point decide if we should branch into resume, cleanup
548 // or suspend block of the coroutine (see @llvm.coro.suspend return code
549 // documentation).
550 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
551 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
552 op.getCleanupDest()};
553 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
554 op, args: rewriter.create<LLVM::SExtOp>(location: loc, args&: i32, args: coroSuspend.getResult()),
555 /*defaultDestination=*/args: op.getSuspendDest(),
556 /*defaultOperands=*/args: ValueRange(),
557 /*caseValues=*/args&: caseValues,
558 /*caseDestinations=*/args&: caseDest,
559 /*caseOperands=*/args: ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
560 /*branchWeights=*/args: ArrayRef<int32_t>());
561
562 return success();
563 }
564};
565} // namespace
566
567//===----------------------------------------------------------------------===//
568// Convert async.runtime.create to the corresponding runtime API call.
569//
570// To allocate storage for the async values we use getelementptr trick:
571// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
572//===----------------------------------------------------------------------===//
573
574namespace {
575class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
576public:
577 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
578
579 LogicalResult
580 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter) const override {
582 const TypeConverter *converter = getTypeConverter();
583 Type resultType = op->getResultTypes()[0];
584
585 // Tokens creation maps to a simple function call.
586 if (isa<TokenType>(Val: resultType)) {
587 rewriter.replaceOpWithNewOp<func::CallOp>(
588 op, args: kCreateToken, args: converter->convertType(t: resultType));
589 return success();
590 }
591
592 // To create a value we need to compute the storage requirement.
593 if (auto value = dyn_cast<ValueType>(Val&: resultType)) {
594 // Returns the size requirements for the async value storage.
595 auto sizeOf = [&](ValueType valueType) -> Value {
596 auto loc = op->getLoc();
597 auto i64 = rewriter.getI64Type();
598
599 auto storedType = converter->convertType(t: valueType.getValueType());
600 auto storagePtrType =
601 AsyncAPI::opaquePointerType(ctx: rewriter.getContext());
602
603 // %Size = getelementptr %T* null, int 1
604 // %SizeI = ptrtoint %T* %Size to i64
605 auto nullPtr = rewriter.create<LLVM::ZeroOp>(location: loc, args&: storagePtrType);
606 auto gep =
607 rewriter.create<LLVM::GEPOp>(location: loc, args&: storagePtrType, args&: storedType,
608 args&: nullPtr, args: ArrayRef<LLVM::GEPArg>{1});
609 return rewriter.create<LLVM::PtrToIntOp>(location: loc, args&: i64, args&: gep);
610 };
611
612 rewriter.replaceOpWithNewOp<func::CallOp>(op, args: kCreateValue, args&: resultType,
613 args: sizeOf(value));
614
615 return success();
616 }
617
618 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported async type");
619 }
620};
621} // namespace
622
623//===----------------------------------------------------------------------===//
624// Convert async.runtime.create_group to the corresponding runtime API call.
625//===----------------------------------------------------------------------===//
626
627namespace {
628class RuntimeCreateGroupOpLowering
629 : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
630public:
631 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
632
633 LogicalResult
634 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
635 ConversionPatternRewriter &rewriter) const override {
636 const TypeConverter *converter = getTypeConverter();
637 Type resultType = op.getResult().getType();
638
639 rewriter.replaceOpWithNewOp<func::CallOp>(
640 op, args: kCreateGroup, args: converter->convertType(t: resultType),
641 args: adaptor.getOperands());
642 return success();
643 }
644};
645} // namespace
646
647//===----------------------------------------------------------------------===//
648// Convert async.runtime.set_available to the corresponding runtime API call.
649//===----------------------------------------------------------------------===//
650
651namespace {
652class RuntimeSetAvailableOpLowering
653 : public OpConversionPattern<RuntimeSetAvailableOp> {
654public:
655 using OpConversionPattern::OpConversionPattern;
656
657 LogicalResult
658 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter) const override {
660 StringRef apiFuncName =
661 TypeSwitch<Type, StringRef>(op.getOperand().getType())
662 .Case<TokenType>(caseFn: [](Type) { return kEmplaceToken; })
663 .Case<ValueType>(caseFn: [](Type) { return kEmplaceValue; });
664
665 rewriter.replaceOpWithNewOp<func::CallOp>(op, args&: apiFuncName, args: TypeRange(),
666 args: adaptor.getOperands());
667
668 return success();
669 }
670};
671} // namespace
672
673//===----------------------------------------------------------------------===//
674// Convert async.runtime.set_error to the corresponding runtime API call.
675//===----------------------------------------------------------------------===//
676
677namespace {
678class RuntimeSetErrorOpLowering
679 : public OpConversionPattern<RuntimeSetErrorOp> {
680public:
681 using OpConversionPattern::OpConversionPattern;
682
683 LogicalResult
684 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
685 ConversionPatternRewriter &rewriter) const override {
686 StringRef apiFuncName =
687 TypeSwitch<Type, StringRef>(op.getOperand().getType())
688 .Case<TokenType>(caseFn: [](Type) { return kSetTokenError; })
689 .Case<ValueType>(caseFn: [](Type) { return kSetValueError; });
690
691 rewriter.replaceOpWithNewOp<func::CallOp>(op, args&: apiFuncName, args: TypeRange(),
692 args: adaptor.getOperands());
693
694 return success();
695 }
696};
697} // namespace
698
699//===----------------------------------------------------------------------===//
700// Convert async.runtime.is_error to the corresponding runtime API call.
701//===----------------------------------------------------------------------===//
702
703namespace {
704class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
705public:
706 using OpConversionPattern::OpConversionPattern;
707
708 LogicalResult
709 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 StringRef apiFuncName =
712 TypeSwitch<Type, StringRef>(op.getOperand().getType())
713 .Case<TokenType>(caseFn: [](Type) { return kIsTokenError; })
714 .Case<GroupType>(caseFn: [](Type) { return kIsGroupError; })
715 .Case<ValueType>(caseFn: [](Type) { return kIsValueError; });
716
717 rewriter.replaceOpWithNewOp<func::CallOp>(
718 op, args&: apiFuncName, args: rewriter.getI1Type(), args: adaptor.getOperands());
719 return success();
720 }
721};
722} // namespace
723
724//===----------------------------------------------------------------------===//
725// Convert async.runtime.await to the corresponding runtime API call.
726//===----------------------------------------------------------------------===//
727
728namespace {
729class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
730public:
731 using OpConversionPattern::OpConversionPattern;
732
733 LogicalResult
734 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
735 ConversionPatternRewriter &rewriter) const override {
736 StringRef apiFuncName =
737 TypeSwitch<Type, StringRef>(op.getOperand().getType())
738 .Case<TokenType>(caseFn: [](Type) { return kAwaitToken; })
739 .Case<ValueType>(caseFn: [](Type) { return kAwaitValue; })
740 .Case<GroupType>(caseFn: [](Type) { return kAwaitGroup; });
741
742 rewriter.create<func::CallOp>(location: op->getLoc(), args&: apiFuncName, args: TypeRange(),
743 args: adaptor.getOperands());
744 rewriter.eraseOp(op);
745
746 return success();
747 }
748};
749} // namespace
750
751//===----------------------------------------------------------------------===//
752// Convert async.runtime.await_and_resume to the corresponding runtime API call.
753//===----------------------------------------------------------------------===//
754
755namespace {
756class RuntimeAwaitAndResumeOpLowering
757 : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
758public:
759 using AsyncOpConversionPattern::AsyncOpConversionPattern;
760
761 LogicalResult
762 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
763 ConversionPatternRewriter &rewriter) const override {
764 StringRef apiFuncName =
765 TypeSwitch<Type, StringRef>(op.getOperand().getType())
766 .Case<TokenType>(caseFn: [](Type) { return kAwaitTokenAndExecute; })
767 .Case<ValueType>(caseFn: [](Type) { return kAwaitValueAndExecute; })
768 .Case<GroupType>(caseFn: [](Type) { return kAwaitAllAndExecute; });
769
770 Value operand = adaptor.getOperand();
771 Value handle = adaptor.getHandle();
772
773 // A pointer to coroutine resume intrinsic wrapper.
774 addResumeFunction(module: op->getParentOfType<ModuleOp>());
775 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
776 location: op->getLoc(), args: AsyncAPI::opaquePointerType(ctx: rewriter.getContext()),
777 args: kResume);
778
779 rewriter.create<func::CallOp>(
780 location: op->getLoc(), args&: apiFuncName, args: TypeRange(),
781 args: ValueRange({operand, handle, resumePtr.getRes()}));
782 rewriter.eraseOp(op);
783
784 return success();
785 }
786};
787} // namespace
788
789//===----------------------------------------------------------------------===//
790// Convert async.runtime.resume to the corresponding runtime API call.
791//===----------------------------------------------------------------------===//
792
793namespace {
794class RuntimeResumeOpLowering
795 : public AsyncOpConversionPattern<RuntimeResumeOp> {
796public:
797 using AsyncOpConversionPattern::AsyncOpConversionPattern;
798
799 LogicalResult
800 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
801 ConversionPatternRewriter &rewriter) const override {
802 // A pointer to coroutine resume intrinsic wrapper.
803 addResumeFunction(module: op->getParentOfType<ModuleOp>());
804 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
805 location: op->getLoc(), args: AsyncAPI::opaquePointerType(ctx: rewriter.getContext()),
806 args: kResume);
807
808 // Call async runtime API to execute a coroutine in the managed thread.
809 auto coroHdl = adaptor.getHandle();
810 rewriter.replaceOpWithNewOp<func::CallOp>(
811 op, args: TypeRange(), args: kExecute, args: ValueRange({coroHdl, resumePtr.getRes()}));
812
813 return success();
814 }
815};
816} // namespace
817
818//===----------------------------------------------------------------------===//
819// Convert async.runtime.store to the corresponding runtime API call.
820//===----------------------------------------------------------------------===//
821
822namespace {
823class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
824public:
825 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
826
827 LogicalResult
828 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
829 ConversionPatternRewriter &rewriter) const override {
830 Location loc = op->getLoc();
831
832 // Get a pointer to the async value storage from the runtime.
833 auto ptrType = AsyncAPI::opaquePointerType(ctx: rewriter.getContext());
834 auto storage = adaptor.getStorage();
835 auto storagePtr = rewriter.create<func::CallOp>(
836 location: loc, args: kGetValueStorage, args: TypeRange(ptrType), args&: storage);
837
838 // Cast from i8* to the LLVM pointer type.
839 auto valueType = op.getValue().getType();
840 auto llvmValueType = getTypeConverter()->convertType(t: valueType);
841 if (!llvmValueType)
842 return rewriter.notifyMatchFailure(
843 arg&: op, msg: "failed to convert stored value type to LLVM type");
844
845 Value castedStoragePtr = storagePtr.getResult(i: 0);
846 // Store the yielded value into the async value storage.
847 auto value = adaptor.getValue();
848 rewriter.create<LLVM::StoreOp>(location: loc, args&: value, args&: castedStoragePtr);
849
850 // Erase the original runtime store operation.
851 rewriter.eraseOp(op);
852
853 return success();
854 }
855};
856} // namespace
857
858//===----------------------------------------------------------------------===//
859// Convert async.runtime.load to the corresponding runtime API call.
860//===----------------------------------------------------------------------===//
861
862namespace {
863class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
864public:
865 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
866
867 LogicalResult
868 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
869 ConversionPatternRewriter &rewriter) const override {
870 Location loc = op->getLoc();
871
872 // Get a pointer to the async value storage from the runtime.
873 auto ptrType = AsyncAPI::opaquePointerType(ctx: rewriter.getContext());
874 auto storage = adaptor.getStorage();
875 auto storagePtr = rewriter.create<func::CallOp>(
876 location: loc, args: kGetValueStorage, args: TypeRange(ptrType), args&: storage);
877
878 // Cast from i8* to the LLVM pointer type.
879 auto valueType = op.getResult().getType();
880 auto llvmValueType = getTypeConverter()->convertType(t: valueType);
881 if (!llvmValueType)
882 return rewriter.notifyMatchFailure(
883 arg&: op, msg: "failed to convert loaded value type to LLVM type");
884
885 Value castedStoragePtr = storagePtr.getResult(i: 0);
886
887 // Load from the casted pointer.
888 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, args&: llvmValueType,
889 args&: castedStoragePtr);
890
891 return success();
892 }
893};
894} // namespace
895
896//===----------------------------------------------------------------------===//
897// Convert async.runtime.add_to_group to the corresponding runtime API call.
898//===----------------------------------------------------------------------===//
899
900namespace {
901class RuntimeAddToGroupOpLowering
902 : public OpConversionPattern<RuntimeAddToGroupOp> {
903public:
904 using OpConversionPattern::OpConversionPattern;
905
906 LogicalResult
907 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
908 ConversionPatternRewriter &rewriter) const override {
909 // Currently we can only add tokens to the group.
910 if (!isa<TokenType>(Val: op.getOperand().getType()))
911 return rewriter.notifyMatchFailure(arg&: op, msg: "only token type is supported");
912
913 // Replace with a runtime API function call.
914 rewriter.replaceOpWithNewOp<func::CallOp>(
915 op, args: kAddTokenToGroup, args: rewriter.getI64Type(), args: adaptor.getOperands());
916
917 return success();
918 }
919};
920} // namespace
921
922//===----------------------------------------------------------------------===//
923// Convert async.runtime.num_worker_threads to the corresponding runtime API
924// call.
925//===----------------------------------------------------------------------===//
926
927namespace {
928class RuntimeNumWorkerThreadsOpLowering
929 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
930public:
931 using OpConversionPattern::OpConversionPattern;
932
933 LogicalResult
934 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
935 ConversionPatternRewriter &rewriter) const override {
936
937 // Replace with a runtime API function call.
938 rewriter.replaceOpWithNewOp<func::CallOp>(op, args: kGetNumWorkerThreads,
939 args: rewriter.getIndexType());
940
941 return success();
942 }
943};
944} // namespace
945
946//===----------------------------------------------------------------------===//
947// Async reference counting ops lowering (`async.runtime.add_ref` and
948// `async.runtime.drop_ref` to the corresponding API calls).
949//===----------------------------------------------------------------------===//
950
951namespace {
952template <typename RefCountingOp>
953class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
954public:
955 explicit RefCountingOpLowering(const TypeConverter &converter,
956 MLIRContext *ctx, StringRef apiFunctionName)
957 : OpConversionPattern<RefCountingOp>(converter, ctx),
958 apiFunctionName(apiFunctionName) {}
959
960 LogicalResult
961 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
962 ConversionPatternRewriter &rewriter) const override {
963 auto count = rewriter.create<arith::ConstantOp>(
964 op->getLoc(), rewriter.getI64Type(),
965 rewriter.getI64IntegerAttr(value: op.getCount()));
966
967 auto operand = adaptor.getOperand();
968 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
969 ValueRange({operand, count}));
970
971 return success();
972 }
973
974private:
975 StringRef apiFunctionName;
976};
977
978class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
979public:
980 explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
981 MLIRContext *ctx)
982 : RefCountingOpLowering(converter, ctx, kAddRef) {}
983};
984
985class RuntimeDropRefOpLowering
986 : public RefCountingOpLowering<RuntimeDropRefOp> {
987public:
988 explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
989 MLIRContext *ctx)
990 : RefCountingOpLowering(converter, ctx, kDropRef) {}
991};
992} // namespace
993
994//===----------------------------------------------------------------------===//
995// Convert return operations that return async values from async regions.
996//===----------------------------------------------------------------------===//
997
998namespace {
999class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
1000public:
1001 using OpConversionPattern::OpConversionPattern;
1002
1003 LogicalResult
1004 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1005 ConversionPatternRewriter &rewriter) const override {
1006 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, args: adaptor.getOperands());
1007 return success();
1008 }
1009};
1010} // namespace
1011
1012//===----------------------------------------------------------------------===//
1013
1014namespace {
1015struct ConvertAsyncToLLVMPass
1016 : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1017 using Base::Base;
1018
1019 void runOnOperation() override;
1020};
1021} // namespace
1022
1023void ConvertAsyncToLLVMPass::runOnOperation() {
1024 ModuleOp module = getOperation();
1025 MLIRContext *ctx = module->getContext();
1026
1027 LowerToLLVMOptions options(ctx);
1028
1029 // Add declarations for most functions required by the coroutines lowering.
1030 // We delay adding the resume function until it's needed because it currently
1031 // fails to compile unless '-O0' is specified.
1032 addAsyncRuntimeApiDeclarations(module);
1033
1034 // Lower async.runtime and async.coro operations to Async Runtime API and
1035 // LLVM coroutine intrinsics.
1036
1037 // Convert async dialect types and operations to LLVM dialect.
1038 AsyncRuntimeTypeConverter converter(options);
1039 RewritePatternSet patterns(ctx);
1040
1041 // We use conversion to LLVM type to lower async.runtime load and store
1042 // operations.
1043 LLVMTypeConverter llvmConverter(ctx, options);
1044 llvmConverter.addConversion(callback: [&](Type type) {
1045 return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1046 });
1047
1048 // Convert async types in function signatures and function calls.
1049 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1050 converter);
1051 populateCallOpTypeConversionPattern(patterns, converter);
1052
1053 // Convert return operations inside async.execute regions.
1054 patterns.add<ReturnOpOpConversion>(arg&: converter, args&: ctx);
1055
1056 // Lower async.runtime operations to the async runtime API calls.
1057 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1058 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1059 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1060 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1061 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(arg&: converter,
1062 args&: ctx);
1063
1064 // Lower async.runtime operations that rely on LLVM type converter to convert
1065 // from async value payload type to the LLVM type.
1066 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1067 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(arg&: llvmConverter);
1068
1069 // Lower async coroutine operations to LLVM coroutine intrinsics.
1070 patterns
1071 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1072 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1073 arg&: converter, args&: ctx);
1074
1075 ConversionTarget target(*ctx);
1076 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1077 UnrealizedConversionCastOp>();
1078 target.addLegalDialect<LLVM::LLVMDialect>();
1079
1080 // All operations from Async dialect must be lowered to the runtime API and
1081 // LLVM intrinsics calls.
1082 target.addIllegalDialect<AsyncDialect>();
1083
1084 // Add dynamic legality constraints to apply conversions defined above.
1085 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
1086 return converter.isSignatureLegal(ty: op.getFunctionType());
1087 });
1088 target.addDynamicallyLegalOp<func::ReturnOp>(callback: [&](func::ReturnOp op) {
1089 return converter.isLegal(range: op.getOperandTypes());
1090 });
1091 target.addDynamicallyLegalOp<func::CallOp>(callback: [&](func::CallOp op) {
1092 return converter.isSignatureLegal(ty: op.getCalleeType());
1093 });
1094
1095 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
1096 signalPassFailure();
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// Patterns for structural type conversions for the Async dialect operations.
1101//===----------------------------------------------------------------------===//
1102
1103namespace {
1104class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1105public:
1106 using OpConversionPattern::OpConversionPattern;
1107 LogicalResult
1108 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1109 ConversionPatternRewriter &rewriter) const override {
1110 ExecuteOp newOp =
1111 cast<ExecuteOp>(Val: rewriter.cloneWithoutRegions(op&: *op.getOperation()));
1112 rewriter.inlineRegionBefore(region&: op.getRegion(), parent&: newOp.getRegion(),
1113 before: newOp.getRegion().end());
1114
1115 // Set operands and update block argument and result types.
1116 newOp->setOperands(adaptor.getOperands());
1117 if (failed(Result: rewriter.convertRegionTypes(region: &newOp.getRegion(), converter: *typeConverter)))
1118 return failure();
1119 for (auto result : newOp.getResults())
1120 result.setType(typeConverter->convertType(t: result.getType()));
1121
1122 rewriter.replaceOp(op, newValues: newOp.getResults());
1123 return success();
1124 }
1125};
1126
1127// Dummy pattern to trigger the appropriate type conversion / materialization.
1128class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1129public:
1130 using OpConversionPattern::OpConversionPattern;
1131 LogicalResult
1132 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1133 ConversionPatternRewriter &rewriter) const override {
1134 rewriter.replaceOpWithNewOp<AwaitOp>(op, args: adaptor.getOperands().front());
1135 return success();
1136 }
1137};
1138
1139// Dummy pattern to trigger the appropriate type conversion / materialization.
1140class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1141public:
1142 using OpConversionPattern::OpConversionPattern;
1143 LogicalResult
1144 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1145 ConversionPatternRewriter &rewriter) const override {
1146 rewriter.replaceOpWithNewOp<async::YieldOp>(op, args: adaptor.getOperands());
1147 return success();
1148 }
1149};
1150} // namespace
1151
1152void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1153 TypeConverter &typeConverter, RewritePatternSet &patterns,
1154 ConversionTarget &target) {
1155 typeConverter.addConversion(callback: [&](TokenType type) { return type; });
1156 typeConverter.addConversion(callback: [&](ValueType type) {
1157 Type converted = typeConverter.convertType(t: type.getValueType());
1158 return converted ? ValueType::get(valueType: converted) : converted;
1159 });
1160
1161 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1162 arg&: typeConverter, args: patterns.getContext());
1163
1164 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1165 callback: [&](Operation *op) { return typeConverter.isLegal(op); });
1166}
1167

// Source: mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp