1//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Async/IR/Async.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
20#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/ImplicitLocOpBuilder.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33#define DEBUG_TYPE "convert-async-to-llvm"
34
35using namespace mlir;
36using namespace mlir::async;
37
38//===----------------------------------------------------------------------===//
39// Async Runtime C API declaration.
40//===----------------------------------------------------------------------===//
41
// Symbol names of the Async Runtime C API. addAsyncRuntimeApiDeclarations()
// declares each of these as a private, body-less `func.func` in the module
// being lowered, so every string below must match, byte for byte, the name of
// a function the async runtime provides (presumably resolved at link time --
// confirm against the runtime sources before renaming anything).
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): "Runtim" (missing 'e') breaks the naming pattern of every
// other entry above. Because this string must match the symbol the runtime
// actually exports, do not "fix" it here without checking the runtime
// implementation first.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70
71namespace {
72/// Async Runtime API function types.
73///
74/// Because we can't create API function signature for type parametrized
75/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76/// lowering all async data types become opaque pointers at runtime.
77struct AsyncAPI {
78 // All async types are lowered to opaque LLVM pointers at runtime.
79 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
80 return LLVM::LLVMPointerType::get(ctx);
81 }
82
83 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
84 return LLVM::LLVMTokenType::get(ctx);
85 }
86
87 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88 auto ref = opaquePointerType(ctx);
89 auto count = IntegerType::get(ctx, 64);
90 return FunctionType::get(ctx, {ref, count}, {});
91 }
92
93 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
94 return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95 }
96
97 static FunctionType createValueFunctionType(MLIRContext *ctx) {
98 auto i64 = IntegerType::get(ctx, 64);
99 auto value = opaquePointerType(ctx);
100 return FunctionType::get(ctx, {i64}, {value});
101 }
102
103 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104 auto i64 = IntegerType::get(ctx, 64);
105 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106 }
107
108 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109 auto ptrType = opaquePointerType(ctx);
110 return FunctionType::get(ctx, {ptrType}, {ptrType});
111 }
112
113 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
114 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115 }
116
117 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118 auto value = opaquePointerType(ctx);
119 return FunctionType::get(ctx, {value}, {});
120 }
121
122 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
123 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124 }
125
126 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127 auto value = opaquePointerType(ctx);
128 return FunctionType::get(ctx, {value}, {});
129 }
130
131 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
132 auto i1 = IntegerType::get(ctx, 1);
133 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134 }
135
136 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137 auto value = opaquePointerType(ctx);
138 auto i1 = IntegerType::get(ctx, 1);
139 return FunctionType::get(ctx, {value}, {i1});
140 }
141
142 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143 auto i1 = IntegerType::get(ctx, 1);
144 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145 }
146
147 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
148 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149 }
150
151 static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152 auto value = opaquePointerType(ctx);
153 return FunctionType::get(ctx, {value}, {});
154 }
155
156 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
157 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158 }
159
160 static FunctionType executeFunctionType(MLIRContext *ctx) {
161 auto ptrType = opaquePointerType(ctx);
162 return FunctionType::get(ctx, {ptrType, ptrType}, {});
163 }
164
165 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166 auto i64 = IntegerType::get(ctx, 64);
167 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168 {i64});
169 }
170
171 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172 auto ptrType = opaquePointerType(ctx);
173 return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174 }
175
176 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177 auto ptrType = opaquePointerType(ctx);
178 return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179 }
180
181 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182 auto ptrType = opaquePointerType(ctx);
183 return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184 }
185
186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188 }
189
190 // Auxiliary coroutine resume intrinsic wrapper.
191 static Type resumeFunctionType(MLIRContext *ctx) {
192 auto voidTy = LLVM::LLVMVoidType::get(ctx);
193 auto ptrType = opaquePointerType(ctx);
194 return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
195 }
196};
197} // namespace
198
199/// Adds Async Runtime C API declarations to the module.
200static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201 auto builder =
202 ImplicitLocOpBuilder::atBlockEnd(loc: module.getLoc(), block: module.getBody());
203
204 auto addFuncDecl = [&](StringRef name, FunctionType type) {
205 if (module.lookupSymbol(name))
206 return;
207 builder.create<func::FuncOp>(name, type).setPrivate();
208 };
209
210 MLIRContext *ctx = module.getContext();
211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229 addFuncDecl(kAwaitTokenAndExecute,
230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231 addFuncDecl(kAwaitValueAndExecute,
232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233 addFuncDecl(kAwaitAllAndExecute,
234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236}
237
238//===----------------------------------------------------------------------===//
239// Coroutine resume function wrapper.
240//===----------------------------------------------------------------------===//
241
242static constexpr const char *kResume = "__resume";
243
244/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245/// intrinsics. We need this function to be able to pass it to the async
246/// runtime execute API.
247static void addResumeFunction(ModuleOp module) {
248 if (module.lookupSymbol(kResume))
249 return;
250
251 MLIRContext *ctx = module.getContext();
252 auto loc = module.getLoc();
253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: module.getBody());
254
255 auto voidTy = LLVM::LLVMVoidType::get(ctx);
256 Type ptrType = AsyncAPI::opaquePointerType(ctx);
257
258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259 kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260 resumeOp.setPrivate();
261
262 auto *block = resumeOp.addEntryBlock(moduleBuilder);
263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: block);
264
265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266 blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267}
268
269//===----------------------------------------------------------------------===//
270// Convert Async dialect types to LLVM types.
271//===----------------------------------------------------------------------===//
272
273namespace {
274/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275/// their runtime type (opaque pointers) and does not convert any other types.
276class AsyncRuntimeTypeConverter : public TypeConverter {
277public:
278 AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279 addConversion(callback: [](Type type) { return type; });
280 addConversion(callback: [](Type type) { return convertAsyncTypes(type); });
281
282 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283 // in patterns for other dialects.
284 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285 ValueRange inputs, Location loc) {
286 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287 return std::optional<Value>(cast.getResult(0));
288 };
289
290 addSourceMaterialization(addUnrealizedCast);
291 addTargetMaterialization(addUnrealizedCast);
292 }
293
294 static std::optional<Type> convertAsyncTypes(Type type) {
295 if (isa<TokenType, GroupType, ValueType>(type))
296 return AsyncAPI::opaquePointerType(type.getContext());
297
298 if (isa<CoroIdType, CoroStateType>(type))
299 return AsyncAPI::tokenType(ctx: type.getContext());
300 if (isa<CoroHandleType>(type))
301 return AsyncAPI::opaquePointerType(type.getContext());
302
303 return std::nullopt;
304 }
305};
306
307/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
308/// as type converter. Allows access to it via the 'getTypeConverter'
309/// convenience method.
310template <typename SourceOp>
311class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
312
313 using Base = OpConversionPattern<SourceOp>;
314
315public:
316 AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
317 MLIRContext *context)
318 : Base(typeConverter, context) {}
319
320 /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321 const AsyncRuntimeTypeConverter *getTypeConverter() const {
322 return static_cast<const AsyncRuntimeTypeConverter *>(
323 Base::getTypeConverter());
324 }
325};
326
327} // namespace
328
329//===----------------------------------------------------------------------===//
330// Convert async.coro.id to @llvm.coro.id intrinsic.
331//===----------------------------------------------------------------------===//
332
333namespace {
334class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
335public:
336 using AsyncOpConversionPattern::AsyncOpConversionPattern;
337
338 LogicalResult
339 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 auto token = AsyncAPI::tokenType(ctx: op->getContext());
342 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343 auto loc = op->getLoc();
344
345 // Constants for initializing coroutine frame.
346 auto constZero =
347 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
348 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
349
350 // Get coroutine id: @llvm.coro.id.
351 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
353
354 return success();
355 }
356};
357} // namespace
358
359//===----------------------------------------------------------------------===//
360// Convert async.coro.begin to @llvm.coro.begin intrinsic.
361//===----------------------------------------------------------------------===//
362
363namespace {
364class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
365public:
366 using AsyncOpConversionPattern::AsyncOpConversionPattern;
367
368 LogicalResult
369 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter) const override {
371 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372 auto loc = op->getLoc();
373
374 // Get coroutine frame size: @llvm.coro.size.i64.
375 Value coroSize =
376 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377 // Get coroutine frame alignment: @llvm.coro.align.i64.
378 Value coroAlign =
379 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380
381 // Round up the size to be multiple of the alignment. Since aligned_alloc
382 // requires the size parameter be an integral multiple of the alignment
383 // parameter.
384 auto makeConstant = [&](uint64_t c) {
385 return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
386 rewriter.getI64Type(), c);
387 };
388 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389 coroSize =
390 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391 Value negCoroAlign =
392 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393 coroSize =
394 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
395
396 // Allocate memory for the coroutine frame.
397 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398 moduleOp: op->getParentOfType<ModuleOp>(), indexType: rewriter.getI64Type());
399 auto coroAlloc = rewriter.create<LLVM::CallOp>(
400 loc, allocFuncOp, ValueRange{coroAlign, coroSize});
401
402 // Begin a coroutine: @llvm.coro.begin.
403 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
404 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
405 op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
406
407 return success();
408 }
409};
410} // namespace
411
412//===----------------------------------------------------------------------===//
413// Convert async.coro.free to @llvm.coro.free intrinsic.
414//===----------------------------------------------------------------------===//
415
416namespace {
417class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
418public:
419 using AsyncOpConversionPattern::AsyncOpConversionPattern;
420
421 LogicalResult
422 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const override {
424 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
425 auto loc = op->getLoc();
426
427 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
428 auto coroMem =
429 rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
430
431 // Free the memory.
432 auto freeFuncOp =
433 LLVM::lookupOrCreateFreeFn(moduleOp: op->getParentOfType<ModuleOp>());
434 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
435 ValueRange(coroMem.getResult()));
436
437 return success();
438 }
439};
440} // namespace
441
442//===----------------------------------------------------------------------===//
443// Convert async.coro.end to @llvm.coro.end intrinsic.
444//===----------------------------------------------------------------------===//
445
446namespace {
447class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
448public:
449 using OpConversionPattern::OpConversionPattern;
450
451 LogicalResult
452 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
453 ConversionPatternRewriter &rewriter) const override {
454 // We are not in the block that is part of the unwind sequence.
455 auto constFalse = rewriter.create<LLVM::ConstantOp>(
456 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
457 auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
458
459 // Mark the end of a coroutine: @llvm.coro.end.
460 auto coroHdl = adaptor.getHandle();
461 rewriter.create<LLVM::CoroEndOp>(
462 op->getLoc(), rewriter.getI1Type(),
463 ValueRange({coroHdl, constFalse, noneToken}));
464 rewriter.eraseOp(op: op);
465
466 return success();
467 }
468};
469} // namespace
470
471//===----------------------------------------------------------------------===//
472// Convert async.coro.save to @llvm.coro.save intrinsic.
473//===----------------------------------------------------------------------===//
474
475namespace {
476class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
477public:
478 using OpConversionPattern::OpConversionPattern;
479
480 LogicalResult
481 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
482 ConversionPatternRewriter &rewriter) const override {
483 // Save the coroutine state: @llvm.coro.save
484 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
485 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
486
487 return success();
488 }
489};
490} // namespace
491
492//===----------------------------------------------------------------------===//
493// Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
494//===----------------------------------------------------------------------===//
495
496namespace {
497
498/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
499/// branch to the appropriate block based on the return code.
500///
501/// Before:
502///
503/// ^suspended:
504/// "opBefore"(...)
505/// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
506/// ^resume:
507/// "op"(...)
508/// ^cleanup: ...
509/// ^suspend: ...
510///
511/// After:
512///
513/// ^suspended:
514/// "opBefore"(...)
515/// %suspend = llmv.intr.coro.suspend ...
516/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
517/// ^resume:
518/// "op"(...)
519/// ^cleanup: ...
520/// ^suspend: ...
521///
522class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
523public:
524 using OpConversionPattern::OpConversionPattern;
525
526 LogicalResult
527 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529 auto i8 = rewriter.getIntegerType(8);
530 auto i32 = rewriter.getI32Type();
531 auto loc = op->getLoc();
532
533 // This is not a final suspension point.
534 auto constFalse = rewriter.create<LLVM::ConstantOp>(
535 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
536
537 // Suspend a coroutine: @llvm.coro.suspend
538 auto coroState = adaptor.getState();
539 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
540 loc, i8, ValueRange({coroState, constFalse}));
541
542 // Cast return code to i32.
543
544 // After a suspension point decide if we should branch into resume, cleanup
545 // or suspend block of the coroutine (see @llvm.coro.suspend return code
546 // documentation).
547 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
548 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
549 op.getCleanupDest()};
550 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
551 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
552 /*defaultDestination=*/op.getSuspendDest(),
553 /*defaultOperands=*/ValueRange(),
554 /*caseValues=*/caseValues,
555 /*caseDestinations=*/caseDest,
556 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
557 /*branchWeights=*/ArrayRef<int32_t>());
558
559 return success();
560 }
561};
562} // namespace
563
564//===----------------------------------------------------------------------===//
565// Convert async.runtime.create to the corresponding runtime API call.
566//
// To compute the storage size of async values we use the getelementptr trick:
568// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
569//===----------------------------------------------------------------------===//
570
571namespace {
572class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
573public:
574 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
575
576 LogicalResult
577 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
578 ConversionPatternRewriter &rewriter) const override {
579 const TypeConverter *converter = getTypeConverter();
580 Type resultType = op->getResultTypes()[0];
581
582 // Tokens creation maps to a simple function call.
583 if (isa<TokenType>(resultType)) {
584 rewriter.replaceOpWithNewOp<func::CallOp>(
585 op, kCreateToken, converter->convertType(resultType));
586 return success();
587 }
588
589 // To create a value we need to compute the storage requirement.
590 if (auto value = dyn_cast<ValueType>(resultType)) {
591 // Returns the size requirements for the async value storage.
592 auto sizeOf = [&](ValueType valueType) -> Value {
593 auto loc = op->getLoc();
594 auto i64 = rewriter.getI64Type();
595
596 auto storedType = converter->convertType(valueType.getValueType());
597 auto storagePtrType =
598 AsyncAPI::opaquePointerType(rewriter.getContext());
599
600 // %Size = getelementptr %T* null, int 1
601 // %SizeI = ptrtoint %T* %Size to i64
602 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
603 auto gep =
604 rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
605 nullPtr, ArrayRef<LLVM::GEPArg>{1});
606 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
607 };
608
609 rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
610 sizeOf(value));
611
612 return success();
613 }
614
615 return rewriter.notifyMatchFailure(op, "unsupported async type");
616 }
617};
618} // namespace
619
620//===----------------------------------------------------------------------===//
621// Convert async.runtime.create_group to the corresponding runtime API call.
622//===----------------------------------------------------------------------===//
623
624namespace {
625class RuntimeCreateGroupOpLowering
626 : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
627public:
628 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
629
630 LogicalResult
631 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
632 ConversionPatternRewriter &rewriter) const override {
633 const TypeConverter *converter = getTypeConverter();
634 Type resultType = op.getResult().getType();
635
636 rewriter.replaceOpWithNewOp<func::CallOp>(
637 op, kCreateGroup, converter->convertType(resultType),
638 adaptor.getOperands());
639 return success();
640 }
641};
642} // namespace
643
644//===----------------------------------------------------------------------===//
645// Convert async.runtime.set_available to the corresponding runtime API call.
646//===----------------------------------------------------------------------===//
647
648namespace {
649class RuntimeSetAvailableOpLowering
650 : public OpConversionPattern<RuntimeSetAvailableOp> {
651public:
652 using OpConversionPattern::OpConversionPattern;
653
654 LogicalResult
655 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
656 ConversionPatternRewriter &rewriter) const override {
657 StringRef apiFuncName =
658 TypeSwitch<Type, StringRef>(op.getOperand().getType())
659 .Case<TokenType>([](Type) { return kEmplaceToken; })
660 .Case<ValueType>([](Type) { return kEmplaceValue; });
661
662 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
663 adaptor.getOperands());
664
665 return success();
666 }
667};
668} // namespace
669
670//===----------------------------------------------------------------------===//
671// Convert async.runtime.set_error to the corresponding runtime API call.
672//===----------------------------------------------------------------------===//
673
674namespace {
675class RuntimeSetErrorOpLowering
676 : public OpConversionPattern<RuntimeSetErrorOp> {
677public:
678 using OpConversionPattern::OpConversionPattern;
679
680 LogicalResult
681 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter) const override {
683 StringRef apiFuncName =
684 TypeSwitch<Type, StringRef>(op.getOperand().getType())
685 .Case<TokenType>([](Type) { return kSetTokenError; })
686 .Case<ValueType>([](Type) { return kSetValueError; });
687
688 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
689 adaptor.getOperands());
690
691 return success();
692 }
693};
694} // namespace
695
696//===----------------------------------------------------------------------===//
697// Convert async.runtime.is_error to the corresponding runtime API call.
698//===----------------------------------------------------------------------===//
699
700namespace {
701class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
702public:
703 using OpConversionPattern::OpConversionPattern;
704
705 LogicalResult
706 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
707 ConversionPatternRewriter &rewriter) const override {
708 StringRef apiFuncName =
709 TypeSwitch<Type, StringRef>(op.getOperand().getType())
710 .Case<TokenType>([](Type) { return kIsTokenError; })
711 .Case<GroupType>([](Type) { return kIsGroupError; })
712 .Case<ValueType>([](Type) { return kIsValueError; });
713
714 rewriter.replaceOpWithNewOp<func::CallOp>(
715 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
716 return success();
717 }
718};
719} // namespace
720
721//===----------------------------------------------------------------------===//
722// Convert async.runtime.await to the corresponding runtime API call.
723//===----------------------------------------------------------------------===//
724
725namespace {
726class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
727public:
728 using OpConversionPattern::OpConversionPattern;
729
730 LogicalResult
731 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
732 ConversionPatternRewriter &rewriter) const override {
733 StringRef apiFuncName =
734 TypeSwitch<Type, StringRef>(op.getOperand().getType())
735 .Case<TokenType>([](Type) { return kAwaitToken; })
736 .Case<ValueType>([](Type) { return kAwaitValue; })
737 .Case<GroupType>([](Type) { return kAwaitGroup; });
738
739 rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
740 adaptor.getOperands());
741 rewriter.eraseOp(op: op);
742
743 return success();
744 }
745};
746} // namespace
747
748//===----------------------------------------------------------------------===//
749// Convert async.runtime.await_and_resume to the corresponding runtime API call.
750//===----------------------------------------------------------------------===//
751
752namespace {
753class RuntimeAwaitAndResumeOpLowering
754 : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
755public:
756 using AsyncOpConversionPattern::AsyncOpConversionPattern;
757
758 LogicalResult
759 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
760 ConversionPatternRewriter &rewriter) const override {
761 StringRef apiFuncName =
762 TypeSwitch<Type, StringRef>(op.getOperand().getType())
763 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
764 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
765 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
766
767 Value operand = adaptor.getOperand();
768 Value handle = adaptor.getHandle();
769
770 // A pointer to coroutine resume intrinsic wrapper.
771 addResumeFunction(op->getParentOfType<ModuleOp>());
772 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
773 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
774 kResume);
775
776 rewriter.create<func::CallOp>(
777 op->getLoc(), apiFuncName, TypeRange(),
778 ValueRange({operand, handle, resumePtr.getRes()}));
779 rewriter.eraseOp(op: op);
780
781 return success();
782 }
783};
784} // namespace
785
786//===----------------------------------------------------------------------===//
787// Convert async.runtime.resume to the corresponding runtime API call.
788//===----------------------------------------------------------------------===//
789
790namespace {
791class RuntimeResumeOpLowering
792 : public AsyncOpConversionPattern<RuntimeResumeOp> {
793public:
794 using AsyncOpConversionPattern::AsyncOpConversionPattern;
795
796 LogicalResult
797 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
798 ConversionPatternRewriter &rewriter) const override {
799 // A pointer to coroutine resume intrinsic wrapper.
800 addResumeFunction(op->getParentOfType<ModuleOp>());
801 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
802 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
803 kResume);
804
805 // Call async runtime API to execute a coroutine in the managed thread.
806 auto coroHdl = adaptor.getHandle();
807 rewriter.replaceOpWithNewOp<func::CallOp>(
808 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
809
810 return success();
811 }
812};
813} // namespace
814
815//===----------------------------------------------------------------------===//
816// Convert async.runtime.store to the corresponding runtime API call.
817//===----------------------------------------------------------------------===//
818
819namespace {
820class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
821public:
822 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
823
824 LogicalResult
825 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
826 ConversionPatternRewriter &rewriter) const override {
827 Location loc = op->getLoc();
828
829 // Get a pointer to the async value storage from the runtime.
830 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
831 auto storage = adaptor.getStorage();
832 auto storagePtr = rewriter.create<func::CallOp>(
833 loc, kGetValueStorage, TypeRange(ptrType), storage);
834
835 // Cast from i8* to the LLVM pointer type.
836 auto valueType = op.getValue().getType();
837 auto llvmValueType = getTypeConverter()->convertType(valueType);
838 if (!llvmValueType)
839 return rewriter.notifyMatchFailure(
840 op, "failed to convert stored value type to LLVM type");
841
842 Value castedStoragePtr = storagePtr.getResult(0);
843 // Store the yielded value into the async value storage.
844 auto value = adaptor.getValue();
845 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
846
847 // Erase the original runtime store operation.
848 rewriter.eraseOp(op: op);
849
850 return success();
851 }
852};
853} // namespace
854
855//===----------------------------------------------------------------------===//
856// Convert async.runtime.load to the corresponding runtime API call.
857//===----------------------------------------------------------------------===//
858
859namespace {
860class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
861public:
862 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
863
864 LogicalResult
865 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
866 ConversionPatternRewriter &rewriter) const override {
867 Location loc = op->getLoc();
868
869 // Get a pointer to the async value storage from the runtime.
870 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
871 auto storage = adaptor.getStorage();
872 auto storagePtr = rewriter.create<func::CallOp>(
873 loc, kGetValueStorage, TypeRange(ptrType), storage);
874
875 // Cast from i8* to the LLVM pointer type.
876 auto valueType = op.getResult().getType();
877 auto llvmValueType = getTypeConverter()->convertType(valueType);
878 if (!llvmValueType)
879 return rewriter.notifyMatchFailure(
880 op, "failed to convert loaded value type to LLVM type");
881
882 Value castedStoragePtr = storagePtr.getResult(0);
883
884 // Load from the casted pointer.
885 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
886 castedStoragePtr);
887
888 return success();
889 }
890};
891} // namespace
892
893//===----------------------------------------------------------------------===//
894// Convert async.runtime.add_to_group to the corresponding runtime API call.
895//===----------------------------------------------------------------------===//
896
897namespace {
898class RuntimeAddToGroupOpLowering
899 : public OpConversionPattern<RuntimeAddToGroupOp> {
900public:
901 using OpConversionPattern::OpConversionPattern;
902
903 LogicalResult
904 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
905 ConversionPatternRewriter &rewriter) const override {
906 // Currently we can only add tokens to the group.
907 if (!isa<TokenType>(op.getOperand().getType()))
908 return rewriter.notifyMatchFailure(op, "only token type is supported");
909
910 // Replace with a runtime API function call.
911 rewriter.replaceOpWithNewOp<func::CallOp>(
912 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
913
914 return success();
915 }
916};
917} // namespace
918
919//===----------------------------------------------------------------------===//
920// Convert async.runtime.num_worker_threads to the corresponding runtime API
921// call.
922//===----------------------------------------------------------------------===//
923
924namespace {
925class RuntimeNumWorkerThreadsOpLowering
926 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
927public:
928 using OpConversionPattern::OpConversionPattern;
929
930 LogicalResult
931 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
932 ConversionPatternRewriter &rewriter) const override {
933
934 // Replace with a runtime API function call.
935 rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
936 rewriter.getIndexType());
937
938 return success();
939 }
940};
941} // namespace
942
943//===----------------------------------------------------------------------===//
944// Async reference counting ops lowering (`async.runtime.add_ref` and
945// `async.runtime.drop_ref` to the corresponding API calls).
946//===----------------------------------------------------------------------===//
947
948namespace {
949template <typename RefCountingOp>
950class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
951public:
952 explicit RefCountingOpLowering(const TypeConverter &converter,
953 MLIRContext *ctx, StringRef apiFunctionName)
954 : OpConversionPattern<RefCountingOp>(converter, ctx),
955 apiFunctionName(apiFunctionName) {}
956
957 LogicalResult
958 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
959 ConversionPatternRewriter &rewriter) const override {
960 auto count = rewriter.create<arith::ConstantOp>(
961 op->getLoc(), rewriter.getI64Type(),
962 rewriter.getI64IntegerAttr(op.getCount()));
963
964 auto operand = adaptor.getOperand();
965 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
966 ValueRange({operand, count}));
967
968 return success();
969 }
970
971private:
972 StringRef apiFunctionName;
973};
974
975class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
976public:
977 explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
978 MLIRContext *ctx)
979 : RefCountingOpLowering(converter, ctx, kAddRef) {}
980};
981
982class RuntimeDropRefOpLowering
983 : public RefCountingOpLowering<RuntimeDropRefOp> {
984public:
985 explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
986 MLIRContext *ctx)
987 : RefCountingOpLowering(converter, ctx, kDropRef) {}
988};
989} // namespace
990
991//===----------------------------------------------------------------------===//
992// Convert return operations that return async values from async regions.
993//===----------------------------------------------------------------------===//
994
995namespace {
996class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
997public:
998 using OpConversionPattern::OpConversionPattern;
999
1000 LogicalResult
1001 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1002 ConversionPatternRewriter &rewriter) const override {
1003 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1004 return success();
1005 }
1006};
1007} // namespace
1008
1009//===----------------------------------------------------------------------===//
1010
1011namespace {
/// Pass that lowers Async dialect types and operations to async runtime API
/// calls and LLVM dialect operations (see `runOnOperation`).
struct ConvertAsyncToLLVMPass
    : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
  // Inherit the constructors of the tablegen-generated pass base class.
  using Base::Base;

  void runOnOperation() override;
};
1018} // namespace
1019
1020void ConvertAsyncToLLVMPass::runOnOperation() {
1021 ModuleOp module = getOperation();
1022 MLIRContext *ctx = module->getContext();
1023
1024 LowerToLLVMOptions options(ctx);
1025
1026 // Add declarations for most functions required by the coroutines lowering.
1027 // We delay adding the resume function until it's needed because it currently
1028 // fails to compile unless '-O0' is specified.
1029 addAsyncRuntimeApiDeclarations(module);
1030
1031 // Lower async.runtime and async.coro operations to Async Runtime API and
1032 // LLVM coroutine intrinsics.
1033
1034 // Convert async dialect types and operations to LLVM dialect.
1035 AsyncRuntimeTypeConverter converter(options);
1036 RewritePatternSet patterns(ctx);
1037
1038 // We use conversion to LLVM type to lower async.runtime load and store
1039 // operations.
1040 LLVMTypeConverter llvmConverter(ctx, options);
1041 llvmConverter.addConversion(callback: [&](Type type) {
1042 return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1043 });
1044
1045 // Convert async types in function signatures and function calls.
1046 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1047 converter);
1048 populateCallOpTypeConversionPattern(patterns, converter);
1049
1050 // Convert return operations inside async.execute regions.
1051 patterns.add<ReturnOpOpConversion>(arg&: converter, args&: ctx);
1052
1053 // Lower async.runtime operations to the async runtime API calls.
1054 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1055 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1056 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1057 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1058 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(arg&: converter,
1059 args&: ctx);
1060
1061 // Lower async.runtime operations that rely on LLVM type converter to convert
1062 // from async value payload type to the LLVM type.
1063 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1064 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(arg&: llvmConverter);
1065
1066 // Lower async coroutine operations to LLVM coroutine intrinsics.
1067 patterns
1068 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1069 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1070 arg&: converter, args&: ctx);
1071
1072 ConversionTarget target(*ctx);
1073 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1074 UnrealizedConversionCastOp>();
1075 target.addLegalDialect<LLVM::LLVMDialect>();
1076
1077 // All operations from Async dialect must be lowered to the runtime API and
1078 // LLVM intrinsics calls.
1079 target.addIllegalDialect<AsyncDialect>();
1080
1081 // Add dynamic legality constraints to apply conversions defined above.
1082 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1083 return converter.isSignatureLegal(op.getFunctionType());
1084 });
1085 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1086 return converter.isLegal(op.getOperandTypes());
1087 });
1088 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1089 return converter.isSignatureLegal(op.getCalleeType());
1090 });
1091
1092 if (failed(applyPartialConversion(module, target, std::move(patterns))))
1093 signalPassFailure();
1094}
1095
1096//===----------------------------------------------------------------------===//
1097// Patterns for structural type conversions for the Async dialect operations.
1098//===----------------------------------------------------------------------===//
1099
1100namespace {
1101class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1102public:
1103 using OpConversionPattern::OpConversionPattern;
1104 LogicalResult
1105 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1106 ConversionPatternRewriter &rewriter) const override {
1107 ExecuteOp newOp =
1108 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1109 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1110 newOp.getRegion().end());
1111
1112 // Set operands and update block argument and result types.
1113 newOp->setOperands(adaptor.getOperands());
1114 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1115 return failure();
1116 for (auto result : newOp.getResults())
1117 result.setType(typeConverter->convertType(result.getType()));
1118
1119 rewriter.replaceOp(op, newOp.getResults());
1120 return success();
1121 }
1122};
1123
1124// Dummy pattern to trigger the appropriate type conversion / materialization.
1125class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1126public:
1127 using OpConversionPattern::OpConversionPattern;
1128 LogicalResult
1129 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1130 ConversionPatternRewriter &rewriter) const override {
1131 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1132 return success();
1133 }
1134};
1135
1136// Dummy pattern to trigger the appropriate type conversion / materialization.
1137class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1138public:
1139 using OpConversionPattern::OpConversionPattern;
1140 LogicalResult
1141 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1142 ConversionPatternRewriter &rewriter) const override {
1143 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1144 return success();
1145 }
1146};
1147} // namespace
1148
1149void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1150 TypeConverter &typeConverter, RewritePatternSet &patterns,
1151 ConversionTarget &target) {
1152 typeConverter.addConversion([&](TokenType type) { return type; });
1153 typeConverter.addConversion(callback: [&](ValueType type) {
1154 Type converted = typeConverter.convertType(type.getValueType());
1155 return converted ? ValueType::get(converted) : converted;
1156 });
1157
1158 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1159 arg&: typeConverter, args: patterns.getContext());
1160
1161 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1162 [&](Operation *op) { return typeConverter.isLegal(op); });
1163}
1164

source code of mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp