1//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Async/IR/Async.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
20#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/IR/ImplicitLocOpBuilder.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33#define DEBUG_TYPE "convert-async-to-llvm"
34
35using namespace mlir;
36using namespace mlir::async;
37
38//===----------------------------------------------------------------------===//
39// Async Runtime C API declaration.
40//===----------------------------------------------------------------------===//
41
// --- Reference counting ----------------------------------------------------
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// --- Creating tokens, values and groups ------------------------------------
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// --- Switching tokens/values to the available state ------------------------
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// --- Error state: setters and queries --------------------------------------
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// --- Blocking awaits --------------------------------------------------------
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// --- Scheduling a coroutine resume function --------------------------------
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// --- Value storage and group membership ------------------------------------
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// --- Runtime introspection --------------------------------------------------
// NOTE(review): "Runtim" (missing 'e') differs from every other name above.
// The string must match the symbol the runtime library actually exports, so
// confirm against the runtime before "fixing" the spelling here.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70
71namespace {
72/// Async Runtime API function types.
73///
74/// Because we can't create API function signature for type parametrized
75/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76/// lowering all async data types become opaque pointers at runtime.
77struct AsyncAPI {
78 // All async types are lowered to opaque LLVM pointers at runtime.
79 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
80 return LLVM::LLVMPointerType::get(ctx);
81 }
82
83 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
84 return LLVM::LLVMTokenType::get(ctx);
85 }
86
87 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88 auto ref = opaquePointerType(ctx);
89 auto count = IntegerType::get(ctx, 64);
90 return FunctionType::get(ctx, {ref, count}, {});
91 }
92
93 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
94 return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95 }
96
97 static FunctionType createValueFunctionType(MLIRContext *ctx) {
98 auto i64 = IntegerType::get(ctx, 64);
99 auto value = opaquePointerType(ctx);
100 return FunctionType::get(ctx, {i64}, {value});
101 }
102
103 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104 auto i64 = IntegerType::get(ctx, 64);
105 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106 }
107
108 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109 auto ptrType = opaquePointerType(ctx);
110 return FunctionType::get(ctx, {ptrType}, {ptrType});
111 }
112
113 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
114 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115 }
116
117 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118 auto value = opaquePointerType(ctx);
119 return FunctionType::get(ctx, {value}, {});
120 }
121
122 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
123 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124 }
125
126 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127 auto value = opaquePointerType(ctx);
128 return FunctionType::get(ctx, {value}, {});
129 }
130
131 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
132 auto i1 = IntegerType::get(ctx, 1);
133 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134 }
135
136 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137 auto value = opaquePointerType(ctx);
138 auto i1 = IntegerType::get(ctx, 1);
139 return FunctionType::get(ctx, {value}, {i1});
140 }
141
142 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143 auto i1 = IntegerType::get(ctx, 1);
144 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145 }
146
147 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
148 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149 }
150
151 static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152 auto value = opaquePointerType(ctx);
153 return FunctionType::get(ctx, {value}, {});
154 }
155
156 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
157 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158 }
159
160 static FunctionType executeFunctionType(MLIRContext *ctx) {
161 auto ptrType = opaquePointerType(ctx);
162 return FunctionType::get(ctx, {ptrType, ptrType}, {});
163 }
164
165 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166 auto i64 = IntegerType::get(ctx, 64);
167 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168 {i64});
169 }
170
171 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172 auto ptrType = opaquePointerType(ctx);
173 return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174 }
175
176 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177 auto ptrType = opaquePointerType(ctx);
178 return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179 }
180
181 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182 auto ptrType = opaquePointerType(ctx);
183 return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184 }
185
186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188 }
189
190 // Auxiliary coroutine resume intrinsic wrapper.
191 static Type resumeFunctionType(MLIRContext *ctx) {
192 auto voidTy = LLVM::LLVMVoidType::get(ctx);
193 auto ptrType = opaquePointerType(ctx);
194 return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
195 }
196};
197} // namespace
198
199/// Adds Async Runtime C API declarations to the module.
200static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201 auto builder =
202 ImplicitLocOpBuilder::atBlockEnd(loc: module.getLoc(), block: module.getBody());
203
204 auto addFuncDecl = [&](StringRef name, FunctionType type) {
205 if (module.lookupSymbol(name))
206 return;
207 builder.create<func::FuncOp>(name, type).setPrivate();
208 };
209
210 MLIRContext *ctx = module.getContext();
211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229 addFuncDecl(kAwaitTokenAndExecute,
230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231 addFuncDecl(kAwaitValueAndExecute,
232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233 addFuncDecl(kAwaitAllAndExecute,
234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236}
237
238//===----------------------------------------------------------------------===//
239// Coroutine resume function wrapper.
240//===----------------------------------------------------------------------===//
241
242static constexpr const char *kResume = "__resume";
243
244/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245/// intrinsics. We need this function to be able to pass it to the async
246/// runtime execute API.
247static void addResumeFunction(ModuleOp module) {
248 if (module.lookupSymbol(kResume))
249 return;
250
251 MLIRContext *ctx = module.getContext();
252 auto loc = module.getLoc();
253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: module.getBody());
254
255 auto voidTy = LLVM::LLVMVoidType::get(ctx);
256 Type ptrType = AsyncAPI::opaquePointerType(ctx);
257
258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259 kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260 resumeOp.setPrivate();
261
262 auto *block = resumeOp.addEntryBlock(moduleBuilder);
263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: block);
264
265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266 blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267}
268
269//===----------------------------------------------------------------------===//
270// Convert Async dialect types to LLVM types.
271//===----------------------------------------------------------------------===//
272
273namespace {
274/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275/// their runtime type (opaque pointers) and does not convert any other types.
276class AsyncRuntimeTypeConverter : public TypeConverter {
277public:
278 AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279 addConversion(callback: [](Type type) { return type; });
280 addConversion(callback: [](Type type) { return convertAsyncTypes(type); });
281
282 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283 // in patterns for other dialects.
284 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285 ValueRange inputs, Location loc) -> Value {
286 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287 return cast.getResult(0);
288 };
289
290 addSourceMaterialization(callback&: addUnrealizedCast);
291 addTargetMaterialization(callback&: addUnrealizedCast);
292 }
293
294 static std::optional<Type> convertAsyncTypes(Type type) {
295 if (isa<TokenType, GroupType, ValueType>(type))
296 return AsyncAPI::opaquePointerType(type.getContext());
297
298 if (isa<CoroIdType, CoroStateType>(type))
299 return AsyncAPI::tokenType(ctx: type.getContext());
300 if (isa<CoroHandleType>(type))
301 return AsyncAPI::opaquePointerType(type.getContext());
302
303 return std::nullopt;
304 }
305};
306
307/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
308/// as type converter. Allows access to it via the 'getTypeConverter'
309/// convenience method.
310template <typename SourceOp>
311class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
312
313 using Base = OpConversionPattern<SourceOp>;
314
315public:
316 AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
317 MLIRContext *context)
318 : Base(typeConverter, context) {}
319
320 /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321 const AsyncRuntimeTypeConverter *getTypeConverter() const {
322 return static_cast<const AsyncRuntimeTypeConverter *>(
323 Base::getTypeConverter());
324 }
325};
326
327} // namespace
328
329//===----------------------------------------------------------------------===//
330// Convert async.coro.id to @llvm.coro.id intrinsic.
331//===----------------------------------------------------------------------===//
332
333namespace {
334class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
335public:
336 using AsyncOpConversionPattern::AsyncOpConversionPattern;
337
338 LogicalResult
339 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 auto token = AsyncAPI::tokenType(ctx: op->getContext());
342 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343 auto loc = op->getLoc();
344
345 // Constants for initializing coroutine frame.
346 auto constZero =
347 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
348 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
349
350 // Get coroutine id: @llvm.coro.id.
351 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
353
354 return success();
355 }
356};
357} // namespace
358
359//===----------------------------------------------------------------------===//
360// Convert async.coro.begin to @llvm.coro.begin intrinsic.
361//===----------------------------------------------------------------------===//
362
363namespace {
364class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
365public:
366 using AsyncOpConversionPattern::AsyncOpConversionPattern;
367
368 LogicalResult
369 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter) const override {
371 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372 auto loc = op->getLoc();
373
374 // Get coroutine frame size: @llvm.coro.size.i64.
375 Value coroSize =
376 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377 // Get coroutine frame alignment: @llvm.coro.align.i64.
378 Value coroAlign =
379 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380
381 // Round up the size to be multiple of the alignment. Since aligned_alloc
382 // requires the size parameter be an integral multiple of the alignment
383 // parameter.
384 auto makeConstant = [&](uint64_t c) {
385 return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
386 rewriter.getI64Type(), c);
387 };
388 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389 coroSize =
390 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391 Value negCoroAlign =
392 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393 coroSize =
394 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
395
396 // Allocate memory for the coroutine frame.
397 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398 b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>(), indexType: rewriter.getI64Type());
399 if (failed(allocFuncOp))
400 return failure();
401 auto coroAlloc = rewriter.create<LLVM::CallOp>(
402 loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
403
404 // Begin a coroutine: @llvm.coro.begin.
405 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
406 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
407 op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
408
409 return success();
410 }
411};
412} // namespace
413
414//===----------------------------------------------------------------------===//
415// Convert async.coro.free to @llvm.coro.free intrinsic.
416//===----------------------------------------------------------------------===//
417
418namespace {
419class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
420public:
421 using AsyncOpConversionPattern::AsyncOpConversionPattern;
422
423 LogicalResult
424 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
427 auto loc = op->getLoc();
428
429 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
430 auto coroMem =
431 rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
432
433 // Free the memory.
434 auto freeFuncOp =
435 LLVM::lookupOrCreateFreeFn(b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>());
436 if (failed(freeFuncOp))
437 return failure();
438 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
439 ValueRange(coroMem.getResult()));
440
441 return success();
442 }
443};
444} // namespace
445
446//===----------------------------------------------------------------------===//
447// Convert async.coro.end to @llvm.coro.end intrinsic.
448//===----------------------------------------------------------------------===//
449
450namespace {
451class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
452public:
453 using OpConversionPattern::OpConversionPattern;
454
455 LogicalResult
456 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter) const override {
458 // We are not in the block that is part of the unwind sequence.
459 auto constFalse = rewriter.create<LLVM::ConstantOp>(
460 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
461 auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
462
463 // Mark the end of a coroutine: @llvm.coro.end.
464 auto coroHdl = adaptor.getHandle();
465 rewriter.create<LLVM::CoroEndOp>(
466 op->getLoc(), rewriter.getI1Type(),
467 ValueRange({coroHdl, constFalse, noneToken}));
468 rewriter.eraseOp(op: op);
469
470 return success();
471 }
472};
473} // namespace
474
475//===----------------------------------------------------------------------===//
476// Convert async.coro.save to @llvm.coro.save intrinsic.
477//===----------------------------------------------------------------------===//
478
479namespace {
480class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
481public:
482 using OpConversionPattern::OpConversionPattern;
483
484 LogicalResult
485 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const override {
487 // Save the coroutine state: @llvm.coro.save
488 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
489 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
490
491 return success();
492 }
493};
494} // namespace
495
496//===----------------------------------------------------------------------===//
497// Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
498//===----------------------------------------------------------------------===//
499
500namespace {
501
502/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
503/// branch to the appropriate block based on the return code.
504///
505/// Before:
506///
507/// ^suspended:
508/// "opBefore"(...)
509/// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
510/// ^resume:
511/// "op"(...)
512/// ^cleanup: ...
513/// ^suspend: ...
514///
515/// After:
516///
517/// ^suspended:
518/// "opBefore"(...)
519/// %suspend = llmv.intr.coro.suspend ...
520/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
521/// ^resume:
522/// "op"(...)
523/// ^cleanup: ...
524/// ^suspend: ...
525///
526class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
527public:
528 using OpConversionPattern::OpConversionPattern;
529
530 LogicalResult
531 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter) const override {
533 auto i8 = rewriter.getIntegerType(8);
534 auto i32 = rewriter.getI32Type();
535 auto loc = op->getLoc();
536
537 // This is not a final suspension point.
538 auto constFalse = rewriter.create<LLVM::ConstantOp>(
539 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
540
541 // Suspend a coroutine: @llvm.coro.suspend
542 auto coroState = adaptor.getState();
543 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
544 loc, i8, ValueRange({coroState, constFalse}));
545
546 // Cast return code to i32.
547
548 // After a suspension point decide if we should branch into resume, cleanup
549 // or suspend block of the coroutine (see @llvm.coro.suspend return code
550 // documentation).
551 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
552 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
553 op.getCleanupDest()};
554 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
555 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
556 /*defaultDestination=*/op.getSuspendDest(),
557 /*defaultOperands=*/ValueRange(),
558 /*caseValues=*/caseValues,
559 /*caseDestinations=*/caseDest,
560 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
561 /*branchWeights=*/ArrayRef<int32_t>());
562
563 return success();
564 }
565};
566} // namespace
567
568//===----------------------------------------------------------------------===//
569// Convert async.runtime.create to the corresponding runtime API call.
570//
571// To allocate storage for the async values we use getelementptr trick:
572// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
573//===----------------------------------------------------------------------===//
574
575namespace {
576class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
577public:
578 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
579
580 LogicalResult
581 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
582 ConversionPatternRewriter &rewriter) const override {
583 const TypeConverter *converter = getTypeConverter();
584 Type resultType = op->getResultTypes()[0];
585
586 // Tokens creation maps to a simple function call.
587 if (isa<TokenType>(resultType)) {
588 rewriter.replaceOpWithNewOp<func::CallOp>(
589 op, kCreateToken, converter->convertType(resultType));
590 return success();
591 }
592
593 // To create a value we need to compute the storage requirement.
594 if (auto value = dyn_cast<ValueType>(resultType)) {
595 // Returns the size requirements for the async value storage.
596 auto sizeOf = [&](ValueType valueType) -> Value {
597 auto loc = op->getLoc();
598 auto i64 = rewriter.getI64Type();
599
600 auto storedType = converter->convertType(valueType.getValueType());
601 auto storagePtrType =
602 AsyncAPI::opaquePointerType(rewriter.getContext());
603
604 // %Size = getelementptr %T* null, int 1
605 // %SizeI = ptrtoint %T* %Size to i64
606 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
607 auto gep =
608 rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
609 nullPtr, ArrayRef<LLVM::GEPArg>{1});
610 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
611 };
612
613 rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
614 sizeOf(value));
615
616 return success();
617 }
618
619 return rewriter.notifyMatchFailure(op, "unsupported async type");
620 }
621};
622} // namespace
623
624//===----------------------------------------------------------------------===//
625// Convert async.runtime.create_group to the corresponding runtime API call.
626//===----------------------------------------------------------------------===//
627
628namespace {
629class RuntimeCreateGroupOpLowering
630 : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
631public:
632 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
633
634 LogicalResult
635 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
636 ConversionPatternRewriter &rewriter) const override {
637 const TypeConverter *converter = getTypeConverter();
638 Type resultType = op.getResult().getType();
639
640 rewriter.replaceOpWithNewOp<func::CallOp>(
641 op, kCreateGroup, converter->convertType(resultType),
642 adaptor.getOperands());
643 return success();
644 }
645};
646} // namespace
647
648//===----------------------------------------------------------------------===//
649// Convert async.runtime.set_available to the corresponding runtime API call.
650//===----------------------------------------------------------------------===//
651
652namespace {
653class RuntimeSetAvailableOpLowering
654 : public OpConversionPattern<RuntimeSetAvailableOp> {
655public:
656 using OpConversionPattern::OpConversionPattern;
657
658 LogicalResult
659 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter) const override {
661 StringRef apiFuncName =
662 TypeSwitch<Type, StringRef>(op.getOperand().getType())
663 .Case<TokenType>([](Type) { return kEmplaceToken; })
664 .Case<ValueType>([](Type) { return kEmplaceValue; });
665
666 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
667 adaptor.getOperands());
668
669 return success();
670 }
671};
672} // namespace
673
674//===----------------------------------------------------------------------===//
675// Convert async.runtime.set_error to the corresponding runtime API call.
676//===----------------------------------------------------------------------===//
677
678namespace {
679class RuntimeSetErrorOpLowering
680 : public OpConversionPattern<RuntimeSetErrorOp> {
681public:
682 using OpConversionPattern::OpConversionPattern;
683
684 LogicalResult
685 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
686 ConversionPatternRewriter &rewriter) const override {
687 StringRef apiFuncName =
688 TypeSwitch<Type, StringRef>(op.getOperand().getType())
689 .Case<TokenType>([](Type) { return kSetTokenError; })
690 .Case<ValueType>([](Type) { return kSetValueError; });
691
692 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
693 adaptor.getOperands());
694
695 return success();
696 }
697};
698} // namespace
699
700//===----------------------------------------------------------------------===//
701// Convert async.runtime.is_error to the corresponding runtime API call.
702//===----------------------------------------------------------------------===//
703
704namespace {
705class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
706public:
707 using OpConversionPattern::OpConversionPattern;
708
709 LogicalResult
710 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
711 ConversionPatternRewriter &rewriter) const override {
712 StringRef apiFuncName =
713 TypeSwitch<Type, StringRef>(op.getOperand().getType())
714 .Case<TokenType>([](Type) { return kIsTokenError; })
715 .Case<GroupType>([](Type) { return kIsGroupError; })
716 .Case<ValueType>([](Type) { return kIsValueError; });
717
718 rewriter.replaceOpWithNewOp<func::CallOp>(
719 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
720 return success();
721 }
722};
723} // namespace
724
725//===----------------------------------------------------------------------===//
726// Convert async.runtime.await to the corresponding runtime API call.
727//===----------------------------------------------------------------------===//
728
729namespace {
730class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
731public:
732 using OpConversionPattern::OpConversionPattern;
733
734 LogicalResult
735 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter) const override {
737 StringRef apiFuncName =
738 TypeSwitch<Type, StringRef>(op.getOperand().getType())
739 .Case<TokenType>([](Type) { return kAwaitToken; })
740 .Case<ValueType>([](Type) { return kAwaitValue; })
741 .Case<GroupType>([](Type) { return kAwaitGroup; });
742
743 rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
744 adaptor.getOperands());
745 rewriter.eraseOp(op: op);
746
747 return success();
748 }
749};
750} // namespace
751
752//===----------------------------------------------------------------------===//
753// Convert async.runtime.await_and_resume to the corresponding runtime API call.
754//===----------------------------------------------------------------------===//
755
756namespace {
757class RuntimeAwaitAndResumeOpLowering
758 : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
759public:
760 using AsyncOpConversionPattern::AsyncOpConversionPattern;
761
762 LogicalResult
763 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter) const override {
765 StringRef apiFuncName =
766 TypeSwitch<Type, StringRef>(op.getOperand().getType())
767 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
768 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
769 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
770
771 Value operand = adaptor.getOperand();
772 Value handle = adaptor.getHandle();
773
774 // A pointer to coroutine resume intrinsic wrapper.
775 addResumeFunction(op->getParentOfType<ModuleOp>());
776 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
777 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
778 kResume);
779
780 rewriter.create<func::CallOp>(
781 op->getLoc(), apiFuncName, TypeRange(),
782 ValueRange({operand, handle, resumePtr.getRes()}));
783 rewriter.eraseOp(op: op);
784
785 return success();
786 }
787};
788} // namespace
789
790//===----------------------------------------------------------------------===//
791// Convert async.runtime.resume to the corresponding runtime API call.
792//===----------------------------------------------------------------------===//
793
794namespace {
795class RuntimeResumeOpLowering
796 : public AsyncOpConversionPattern<RuntimeResumeOp> {
797public:
798 using AsyncOpConversionPattern::AsyncOpConversionPattern;
799
800 LogicalResult
801 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
802 ConversionPatternRewriter &rewriter) const override {
803 // A pointer to coroutine resume intrinsic wrapper.
804 addResumeFunction(op->getParentOfType<ModuleOp>());
805 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
806 op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
807 kResume);
808
809 // Call async runtime API to execute a coroutine in the managed thread.
810 auto coroHdl = adaptor.getHandle();
811 rewriter.replaceOpWithNewOp<func::CallOp>(
812 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
813
814 return success();
815 }
816};
817} // namespace
818
819//===----------------------------------------------------------------------===//
820// Convert async.runtime.store to the corresponding runtime API call.
821//===----------------------------------------------------------------------===//
822
823namespace {
824class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
825public:
826 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
827
828 LogicalResult
829 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
830 ConversionPatternRewriter &rewriter) const override {
831 Location loc = op->getLoc();
832
833 // Get a pointer to the async value storage from the runtime.
834 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
835 auto storage = adaptor.getStorage();
836 auto storagePtr = rewriter.create<func::CallOp>(
837 loc, kGetValueStorage, TypeRange(ptrType), storage);
838
839 // Cast from i8* to the LLVM pointer type.
840 auto valueType = op.getValue().getType();
841 auto llvmValueType = getTypeConverter()->convertType(valueType);
842 if (!llvmValueType)
843 return rewriter.notifyMatchFailure(
844 op, "failed to convert stored value type to LLVM type");
845
846 Value castedStoragePtr = storagePtr.getResult(0);
847 // Store the yielded value into the async value storage.
848 auto value = adaptor.getValue();
849 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
850
851 // Erase the original runtime store operation.
852 rewriter.eraseOp(op: op);
853
854 return success();
855 }
856};
857} // namespace
858
859//===----------------------------------------------------------------------===//
860// Convert async.runtime.load to the corresponding runtime API call.
861//===----------------------------------------------------------------------===//
862
863namespace {
864class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
865public:
866 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
867
868 LogicalResult
869 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
870 ConversionPatternRewriter &rewriter) const override {
871 Location loc = op->getLoc();
872
873 // Get a pointer to the async value storage from the runtime.
874 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
875 auto storage = adaptor.getStorage();
876 auto storagePtr = rewriter.create<func::CallOp>(
877 loc, kGetValueStorage, TypeRange(ptrType), storage);
878
879 // Cast from i8* to the LLVM pointer type.
880 auto valueType = op.getResult().getType();
881 auto llvmValueType = getTypeConverter()->convertType(valueType);
882 if (!llvmValueType)
883 return rewriter.notifyMatchFailure(
884 op, "failed to convert loaded value type to LLVM type");
885
886 Value castedStoragePtr = storagePtr.getResult(0);
887
888 // Load from the casted pointer.
889 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
890 castedStoragePtr);
891
892 return success();
893 }
894};
895} // namespace
896
897//===----------------------------------------------------------------------===//
898// Convert async.runtime.add_to_group to the corresponding runtime API call.
899//===----------------------------------------------------------------------===//
900
901namespace {
902class RuntimeAddToGroupOpLowering
903 : public OpConversionPattern<RuntimeAddToGroupOp> {
904public:
905 using OpConversionPattern::OpConversionPattern;
906
907 LogicalResult
908 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
909 ConversionPatternRewriter &rewriter) const override {
910 // Currently we can only add tokens to the group.
911 if (!isa<TokenType>(op.getOperand().getType()))
912 return rewriter.notifyMatchFailure(op, "only token type is supported");
913
914 // Replace with a runtime API function call.
915 rewriter.replaceOpWithNewOp<func::CallOp>(
916 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
917
918 return success();
919 }
920};
921} // namespace
922
923//===----------------------------------------------------------------------===//
924// Convert async.runtime.num_worker_threads to the corresponding runtime API
925// call.
926//===----------------------------------------------------------------------===//
927
928namespace {
929class RuntimeNumWorkerThreadsOpLowering
930 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
931public:
932 using OpConversionPattern::OpConversionPattern;
933
934 LogicalResult
935 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
936 ConversionPatternRewriter &rewriter) const override {
937
938 // Replace with a runtime API function call.
939 rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
940 rewriter.getIndexType());
941
942 return success();
943 }
944};
945} // namespace
946
947//===----------------------------------------------------------------------===//
948// Async reference counting ops lowering (`async.runtime.add_ref` and
949// `async.runtime.drop_ref` to the corresponding API calls).
950//===----------------------------------------------------------------------===//
951
952namespace {
953template <typename RefCountingOp>
954class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
955public:
956 explicit RefCountingOpLowering(const TypeConverter &converter,
957 MLIRContext *ctx, StringRef apiFunctionName)
958 : OpConversionPattern<RefCountingOp>(converter, ctx),
959 apiFunctionName(apiFunctionName) {}
960
961 LogicalResult
962 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
963 ConversionPatternRewriter &rewriter) const override {
964 auto count = rewriter.create<arith::ConstantOp>(
965 op->getLoc(), rewriter.getI64Type(),
966 rewriter.getI64IntegerAttr(op.getCount()));
967
968 auto operand = adaptor.getOperand();
969 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
970 ValueRange({operand, count}));
971
972 return success();
973 }
974
975private:
976 StringRef apiFunctionName;
977};
978
979class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
980public:
981 explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
982 MLIRContext *ctx)
983 : RefCountingOpLowering(converter, ctx, kAddRef) {}
984};
985
986class RuntimeDropRefOpLowering
987 : public RefCountingOpLowering<RuntimeDropRefOp> {
988public:
989 explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
990 MLIRContext *ctx)
991 : RefCountingOpLowering(converter, ctx, kDropRef) {}
992};
993} // namespace
994
995//===----------------------------------------------------------------------===//
996// Convert return operations that return async values from async regions.
997//===----------------------------------------------------------------------===//
998
999namespace {
1000class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
1001public:
1002 using OpConversionPattern::OpConversionPattern;
1003
1004 LogicalResult
1005 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1006 ConversionPatternRewriter &rewriter) const override {
1007 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1008 return success();
1009 }
1010};
1011} // namespace
1012
1013//===----------------------------------------------------------------------===//
1014
1015namespace {
1016struct ConvertAsyncToLLVMPass
1017 : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1018 using Base::Base;
1019
1020 void runOnOperation() override;
1021};
1022} // namespace
1023
1024void ConvertAsyncToLLVMPass::runOnOperation() {
1025 ModuleOp module = getOperation();
1026 MLIRContext *ctx = module->getContext();
1027
1028 LowerToLLVMOptions options(ctx);
1029
1030 // Add declarations for most functions required by the coroutines lowering.
1031 // We delay adding the resume function until it's needed because it currently
1032 // fails to compile unless '-O0' is specified.
1033 addAsyncRuntimeApiDeclarations(module);
1034
1035 // Lower async.runtime and async.coro operations to Async Runtime API and
1036 // LLVM coroutine intrinsics.
1037
1038 // Convert async dialect types and operations to LLVM dialect.
1039 AsyncRuntimeTypeConverter converter(options);
1040 RewritePatternSet patterns(ctx);
1041
1042 // We use conversion to LLVM type to lower async.runtime load and store
1043 // operations.
1044 LLVMTypeConverter llvmConverter(ctx, options);
1045 llvmConverter.addConversion(callback: [&](Type type) {
1046 return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1047 });
1048
1049 // Convert async types in function signatures and function calls.
1050 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1051 converter);
1052 populateCallOpTypeConversionPattern(patterns, converter);
1053
1054 // Convert return operations inside async.execute regions.
1055 patterns.add<ReturnOpOpConversion>(arg&: converter, args&: ctx);
1056
1057 // Lower async.runtime operations to the async runtime API calls.
1058 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1059 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1060 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1061 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1062 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(arg&: converter,
1063 args&: ctx);
1064
1065 // Lower async.runtime operations that rely on LLVM type converter to convert
1066 // from async value payload type to the LLVM type.
1067 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1068 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(arg&: llvmConverter);
1069
1070 // Lower async coroutine operations to LLVM coroutine intrinsics.
1071 patterns
1072 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1073 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1074 arg&: converter, args&: ctx);
1075
1076 ConversionTarget target(*ctx);
1077 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1078 UnrealizedConversionCastOp>();
1079 target.addLegalDialect<LLVM::LLVMDialect>();
1080
1081 // All operations from Async dialect must be lowered to the runtime API and
1082 // LLVM intrinsics calls.
1083 target.addIllegalDialect<AsyncDialect>();
1084
1085 // Add dynamic legality constraints to apply conversions defined above.
1086 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1087 return converter.isSignatureLegal(op.getFunctionType());
1088 });
1089 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1090 return converter.isLegal(op.getOperandTypes());
1091 });
1092 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1093 return converter.isSignatureLegal(op.getCalleeType());
1094 });
1095
1096 if (failed(applyPartialConversion(module, target, std::move(patterns))))
1097 signalPassFailure();
1098}
1099
1100//===----------------------------------------------------------------------===//
1101// Patterns for structural type conversions for the Async dialect operations.
1102//===----------------------------------------------------------------------===//
1103
1104namespace {
1105class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1106public:
1107 using OpConversionPattern::OpConversionPattern;
1108 LogicalResult
1109 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1110 ConversionPatternRewriter &rewriter) const override {
1111 ExecuteOp newOp =
1112 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1113 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1114 newOp.getRegion().end());
1115
1116 // Set operands and update block argument and result types.
1117 newOp->setOperands(adaptor.getOperands());
1118 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1119 return failure();
1120 for (auto result : newOp.getResults())
1121 result.setType(typeConverter->convertType(result.getType()));
1122
1123 rewriter.replaceOp(op, newOp.getResults());
1124 return success();
1125 }
1126};
1127
1128// Dummy pattern to trigger the appropriate type conversion / materialization.
1129class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1130public:
1131 using OpConversionPattern::OpConversionPattern;
1132 LogicalResult
1133 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1134 ConversionPatternRewriter &rewriter) const override {
1135 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1136 return success();
1137 }
1138};
1139
1140// Dummy pattern to trigger the appropriate type conversion / materialization.
1141class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1142public:
1143 using OpConversionPattern::OpConversionPattern;
1144 LogicalResult
1145 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter) const override {
1147 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1148 return success();
1149 }
1150};
1151} // namespace
1152
1153void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1154 TypeConverter &typeConverter, RewritePatternSet &patterns,
1155 ConversionTarget &target) {
1156 typeConverter.addConversion([&](TokenType type) { return type; });
1157 typeConverter.addConversion(callback: [&](ValueType type) {
1158 Type converted = typeConverter.convertType(type.getValueType());
1159 return converted ? ValueType::get(converted) : converted;
1160 });
1161
1162 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1163 arg&: typeConverter, args: patterns.getContext());
1164
1165 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1166 [&](Operation *op) { return typeConverter.isLegal(op); });
1167}
1168

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp