1 | //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Async/IR/Async.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
20 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
22 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
23 | #include "mlir/IR/TypeUtilities.h" |
24 | #include "mlir/Pass/Pass.h" |
25 | #include "mlir/Transforms/DialectConversion.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS |
30 | #include "mlir/Conversion/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | #define DEBUG_TYPE "convert-async-to-llvm" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::async; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Async Runtime C API declaration. |
40 | //===----------------------------------------------------------------------===// |
41 | |
// Symbol names of the Async Runtime C API functions, grouped by purpose.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Switching tokens and values to the available state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Setting and querying the error state.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execution, optionally after awaiting.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// NOTE(review): this symbol is spelled "Runtim" (no trailing 'e'). Presumably
// it has to match the name exported by the runtime library -- confirm before
// "fixing" the spelling here.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70 | |
71 | namespace { |
72 | /// Async Runtime API function types. |
73 | /// |
74 | /// Because we can't create API function signature for type parametrized |
75 | /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After |
76 | /// lowering all async data types become opaque pointers at runtime. |
77 | struct AsyncAPI { |
78 | // All async types are lowered to opaque LLVM pointers at runtime. |
79 | static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { |
80 | return LLVM::LLVMPointerType::get(ctx); |
81 | } |
82 | |
83 | static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { |
84 | return LLVM::LLVMTokenType::get(ctx); |
85 | } |
86 | |
87 | static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { |
88 | auto ref = opaquePointerType(ctx); |
89 | auto count = IntegerType::get(ctx, 64); |
90 | return FunctionType::get(ctx, {ref, count}, {}); |
91 | } |
92 | |
93 | static FunctionType createTokenFunctionType(MLIRContext *ctx) { |
94 | return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); |
95 | } |
96 | |
97 | static FunctionType createValueFunctionType(MLIRContext *ctx) { |
98 | auto i64 = IntegerType::get(ctx, 64); |
99 | auto value = opaquePointerType(ctx); |
100 | return FunctionType::get(ctx, {i64}, {value}); |
101 | } |
102 | |
103 | static FunctionType createGroupFunctionType(MLIRContext *ctx) { |
104 | auto i64 = IntegerType::get(ctx, 64); |
105 | return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); |
106 | } |
107 | |
108 | static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { |
109 | auto ptrType = opaquePointerType(ctx); |
110 | return FunctionType::get(ctx, {ptrType}, {ptrType}); |
111 | } |
112 | |
113 | static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { |
114 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
115 | } |
116 | |
117 | static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { |
118 | auto value = opaquePointerType(ctx); |
119 | return FunctionType::get(ctx, {value}, {}); |
120 | } |
121 | |
122 | static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { |
123 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
124 | } |
125 | |
126 | static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { |
127 | auto value = opaquePointerType(ctx); |
128 | return FunctionType::get(ctx, {value}, {}); |
129 | } |
130 | |
131 | static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { |
132 | auto i1 = IntegerType::get(ctx, 1); |
133 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); |
134 | } |
135 | |
136 | static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { |
137 | auto value = opaquePointerType(ctx); |
138 | auto i1 = IntegerType::get(ctx, 1); |
139 | return FunctionType::get(ctx, {value}, {i1}); |
140 | } |
141 | |
142 | static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { |
143 | auto i1 = IntegerType::get(ctx, 1); |
144 | return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); |
145 | } |
146 | |
147 | static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { |
148 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
149 | } |
150 | |
151 | static FunctionType awaitValueFunctionType(MLIRContext *ctx) { |
152 | auto value = opaquePointerType(ctx); |
153 | return FunctionType::get(ctx, {value}, {}); |
154 | } |
155 | |
156 | static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { |
157 | return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); |
158 | } |
159 | |
160 | static FunctionType executeFunctionType(MLIRContext *ctx) { |
161 | auto ptrType = opaquePointerType(ctx); |
162 | return FunctionType::get(ctx, {ptrType, ptrType}, {}); |
163 | } |
164 | |
165 | static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { |
166 | auto i64 = IntegerType::get(ctx, 64); |
167 | return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, |
168 | {i64}); |
169 | } |
170 | |
171 | static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { |
172 | auto ptrType = opaquePointerType(ctx); |
173 | return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {}); |
174 | } |
175 | |
176 | static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { |
177 | auto ptrType = opaquePointerType(ctx); |
178 | return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {}); |
179 | } |
180 | |
181 | static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { |
182 | auto ptrType = opaquePointerType(ctx); |
183 | return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {}); |
184 | } |
185 | |
186 | static FunctionType getNumWorkerThreads(MLIRContext *ctx) { |
187 | return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); |
188 | } |
189 | |
190 | // Auxiliary coroutine resume intrinsic wrapper. |
191 | static Type resumeFunctionType(MLIRContext *ctx) { |
192 | auto voidTy = LLVM::LLVMVoidType::get(ctx); |
193 | auto ptrType = opaquePointerType(ctx); |
194 | return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false); |
195 | } |
196 | }; |
197 | } // namespace |
198 | |
199 | /// Adds Async Runtime C API declarations to the module. |
200 | static void addAsyncRuntimeApiDeclarations(ModuleOp module) { |
201 | auto builder = |
202 | ImplicitLocOpBuilder::atBlockEnd(loc: module.getLoc(), block: module.getBody()); |
203 | |
204 | auto addFuncDecl = [&](StringRef name, FunctionType type) { |
205 | if (module.lookupSymbol(name)) |
206 | return; |
207 | builder.create<func::FuncOp>(name, type).setPrivate(); |
208 | }; |
209 | |
210 | MLIRContext *ctx = module.getContext(); |
211 | addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); |
212 | addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); |
213 | addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); |
214 | addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); |
215 | addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); |
216 | addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); |
217 | addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); |
218 | addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); |
219 | addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); |
220 | addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); |
221 | addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); |
222 | addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); |
223 | addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); |
224 | addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); |
225 | addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); |
226 | addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); |
227 | addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); |
228 | addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); |
229 | addFuncDecl(kAwaitTokenAndExecute, |
230 | AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); |
231 | addFuncDecl(kAwaitValueAndExecute, |
232 | AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); |
233 | addFuncDecl(kAwaitAllAndExecute, |
234 | AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); |
235 | addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); |
236 | } |
237 | |
238 | //===----------------------------------------------------------------------===// |
239 | // Coroutine resume function wrapper. |
240 | //===----------------------------------------------------------------------===// |
241 | |
242 | static constexpr const char *kResume = "__resume" ; |
243 | |
244 | /// A function that takes a coroutine handle and calls a `llvm.coro.resume` |
245 | /// intrinsics. We need this function to be able to pass it to the async |
246 | /// runtime execute API. |
247 | static void addResumeFunction(ModuleOp module) { |
248 | if (module.lookupSymbol(kResume)) |
249 | return; |
250 | |
251 | MLIRContext *ctx = module.getContext(); |
252 | auto loc = module.getLoc(); |
253 | auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: module.getBody()); |
254 | |
255 | auto voidTy = LLVM::LLVMVoidType::get(ctx); |
256 | Type ptrType = AsyncAPI::opaquePointerType(ctx); |
257 | |
258 | auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( |
259 | kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); |
260 | resumeOp.setPrivate(); |
261 | |
262 | auto *block = resumeOp.addEntryBlock(moduleBuilder); |
263 | auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: block); |
264 | |
265 | blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); |
266 | blockBuilder.create<LLVM::ReturnOp>(ValueRange()); |
267 | } |
268 | |
269 | //===----------------------------------------------------------------------===// |
270 | // Convert Async dialect types to LLVM types. |
271 | //===----------------------------------------------------------------------===// |
272 | |
273 | namespace { |
274 | /// AsyncRuntimeTypeConverter only converts types from the Async dialect to |
275 | /// their runtime type (opaque pointers) and does not convert any other types. |
276 | class AsyncRuntimeTypeConverter : public TypeConverter { |
277 | public: |
278 | AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) { |
279 | addConversion(callback: [](Type type) { return type; }); |
280 | addConversion(callback: [](Type type) { return convertAsyncTypes(type); }); |
281 | |
282 | // Use UnrealizedConversionCast as the bridge so that we don't need to pull |
283 | // in patterns for other dialects. |
284 | auto addUnrealizedCast = [](OpBuilder &builder, Type type, |
285 | ValueRange inputs, Location loc) { |
286 | auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); |
287 | return std::optional<Value>(cast.getResult(0)); |
288 | }; |
289 | |
290 | addSourceMaterialization(addUnrealizedCast); |
291 | addTargetMaterialization(addUnrealizedCast); |
292 | } |
293 | |
294 | static std::optional<Type> convertAsyncTypes(Type type) { |
295 | if (isa<TokenType, GroupType, ValueType>(type)) |
296 | return AsyncAPI::opaquePointerType(type.getContext()); |
297 | |
298 | if (isa<CoroIdType, CoroStateType>(type)) |
299 | return AsyncAPI::tokenType(ctx: type.getContext()); |
300 | if (isa<CoroHandleType>(type)) |
301 | return AsyncAPI::opaquePointerType(type.getContext()); |
302 | |
303 | return std::nullopt; |
304 | } |
305 | }; |
306 | |
307 | /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter |
308 | /// as type converter. Allows access to it via the 'getTypeConverter' |
309 | /// convenience method. |
310 | template <typename SourceOp> |
311 | class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> { |
312 | |
313 | using Base = OpConversionPattern<SourceOp>; |
314 | |
315 | public: |
316 | AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, |
317 | MLIRContext *context) |
318 | : Base(typeConverter, context) {} |
319 | |
320 | /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. |
321 | const AsyncRuntimeTypeConverter *getTypeConverter() const { |
322 | return static_cast<const AsyncRuntimeTypeConverter *>( |
323 | Base::getTypeConverter()); |
324 | } |
325 | }; |
326 | |
327 | } // namespace |
328 | |
329 | //===----------------------------------------------------------------------===// |
330 | // Convert async.coro.id to @llvm.coro.id intrinsic. |
331 | //===----------------------------------------------------------------------===// |
332 | |
333 | namespace { |
334 | class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> { |
335 | public: |
336 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
337 | |
338 | LogicalResult |
339 | matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, |
340 | ConversionPatternRewriter &rewriter) const override { |
341 | auto token = AsyncAPI::tokenType(ctx: op->getContext()); |
342 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
343 | auto loc = op->getLoc(); |
344 | |
345 | // Constants for initializing coroutine frame. |
346 | auto constZero = |
347 | rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
348 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); |
349 | |
350 | // Get coroutine id: @llvm.coro.id. |
351 | rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( |
352 | op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); |
353 | |
354 | return success(); |
355 | } |
356 | }; |
357 | } // namespace |
358 | |
359 | //===----------------------------------------------------------------------===// |
360 | // Convert async.coro.begin to @llvm.coro.begin intrinsic. |
361 | //===----------------------------------------------------------------------===// |
362 | |
363 | namespace { |
364 | class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> { |
365 | public: |
366 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
367 | |
368 | LogicalResult |
369 | matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, |
370 | ConversionPatternRewriter &rewriter) const override { |
371 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
372 | auto loc = op->getLoc(); |
373 | |
374 | // Get coroutine frame size: @llvm.coro.size.i64. |
375 | Value coroSize = |
376 | rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); |
377 | // Get coroutine frame alignment: @llvm.coro.align.i64. |
378 | Value coroAlign = |
379 | rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); |
380 | |
381 | // Round up the size to be multiple of the alignment. Since aligned_alloc |
382 | // requires the size parameter be an integral multiple of the alignment |
383 | // parameter. |
384 | auto makeConstant = [&](uint64_t c) { |
385 | return rewriter.create<LLVM::ConstantOp>(op->getLoc(), |
386 | rewriter.getI64Type(), c); |
387 | }; |
388 | coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); |
389 | coroSize = |
390 | rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); |
391 | Value negCoroAlign = |
392 | rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); |
393 | coroSize = |
394 | rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); |
395 | |
396 | // Allocate memory for the coroutine frame. |
397 | auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( |
398 | moduleOp: op->getParentOfType<ModuleOp>(), indexType: rewriter.getI64Type()); |
399 | auto coroAlloc = rewriter.create<LLVM::CallOp>( |
400 | loc, allocFuncOp, ValueRange{coroAlign, coroSize}); |
401 | |
402 | // Begin a coroutine: @llvm.coro.begin. |
403 | auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); |
404 | rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( |
405 | op, ptrType, ValueRange({coroId, coroAlloc.getResult()})); |
406 | |
407 | return success(); |
408 | } |
409 | }; |
410 | } // namespace |
411 | |
412 | //===----------------------------------------------------------------------===// |
413 | // Convert async.coro.free to @llvm.coro.free intrinsic. |
414 | //===----------------------------------------------------------------------===// |
415 | |
416 | namespace { |
417 | class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> { |
418 | public: |
419 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
420 | |
421 | LogicalResult |
422 | matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, |
423 | ConversionPatternRewriter &rewriter) const override { |
424 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
425 | auto loc = op->getLoc(); |
426 | |
427 | // Get a pointer to the coroutine frame memory: @llvm.coro.free. |
428 | auto coroMem = |
429 | rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); |
430 | |
431 | // Free the memory. |
432 | auto freeFuncOp = |
433 | LLVM::lookupOrCreateFreeFn(moduleOp: op->getParentOfType<ModuleOp>()); |
434 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp, |
435 | ValueRange(coroMem.getResult())); |
436 | |
437 | return success(); |
438 | } |
439 | }; |
440 | } // namespace |
441 | |
442 | //===----------------------------------------------------------------------===// |
443 | // Convert async.coro.end to @llvm.coro.end intrinsic. |
444 | //===----------------------------------------------------------------------===// |
445 | |
446 | namespace { |
447 | class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { |
448 | public: |
449 | using OpConversionPattern::OpConversionPattern; |
450 | |
451 | LogicalResult |
452 | matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, |
453 | ConversionPatternRewriter &rewriter) const override { |
454 | // We are not in the block that is part of the unwind sequence. |
455 | auto constFalse = rewriter.create<LLVM::ConstantOp>( |
456 | op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); |
457 | auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); |
458 | |
459 | // Mark the end of a coroutine: @llvm.coro.end. |
460 | auto coroHdl = adaptor.getHandle(); |
461 | rewriter.create<LLVM::CoroEndOp>( |
462 | op->getLoc(), rewriter.getI1Type(), |
463 | ValueRange({coroHdl, constFalse, noneToken})); |
464 | rewriter.eraseOp(op: op); |
465 | |
466 | return success(); |
467 | } |
468 | }; |
469 | } // namespace |
470 | |
471 | //===----------------------------------------------------------------------===// |
472 | // Convert async.coro.save to @llvm.coro.save intrinsic. |
473 | //===----------------------------------------------------------------------===// |
474 | |
475 | namespace { |
476 | class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { |
477 | public: |
478 | using OpConversionPattern::OpConversionPattern; |
479 | |
480 | LogicalResult |
481 | matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, |
482 | ConversionPatternRewriter &rewriter) const override { |
483 | // Save the coroutine state: @llvm.coro.save |
484 | rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( |
485 | op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); |
486 | |
487 | return success(); |
488 | } |
489 | }; |
490 | } // namespace |
491 | |
492 | //===----------------------------------------------------------------------===// |
493 | // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. |
494 | //===----------------------------------------------------------------------===// |
495 | |
496 | namespace { |
497 | |
498 | /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and |
499 | /// branch to the appropriate block based on the return code. |
500 | /// |
501 | /// Before: |
502 | /// |
503 | /// ^suspended: |
504 | /// "opBefore"(...) |
505 | /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup |
506 | /// ^resume: |
507 | /// "op"(...) |
508 | /// ^cleanup: ... |
509 | /// ^suspend: ... |
510 | /// |
511 | /// After: |
512 | /// |
513 | /// ^suspended: |
514 | /// "opBefore"(...) |
515 | /// %suspend = llmv.intr.coro.suspend ... |
516 | /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] |
517 | /// ^resume: |
518 | /// "op"(...) |
519 | /// ^cleanup: ... |
520 | /// ^suspend: ... |
521 | /// |
522 | class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { |
523 | public: |
524 | using OpConversionPattern::OpConversionPattern; |
525 | |
526 | LogicalResult |
527 | matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, |
528 | ConversionPatternRewriter &rewriter) const override { |
529 | auto i8 = rewriter.getIntegerType(8); |
530 | auto i32 = rewriter.getI32Type(); |
531 | auto loc = op->getLoc(); |
532 | |
533 | // This is not a final suspension point. |
534 | auto constFalse = rewriter.create<LLVM::ConstantOp>( |
535 | loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); |
536 | |
537 | // Suspend a coroutine: @llvm.coro.suspend |
538 | auto coroState = adaptor.getState(); |
539 | auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( |
540 | loc, i8, ValueRange({coroState, constFalse})); |
541 | |
542 | // Cast return code to i32. |
543 | |
544 | // After a suspension point decide if we should branch into resume, cleanup |
545 | // or suspend block of the coroutine (see @llvm.coro.suspend return code |
546 | // documentation). |
547 | llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; |
548 | llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), |
549 | op.getCleanupDest()}; |
550 | rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
551 | op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), |
552 | /*defaultDestination=*/op.getSuspendDest(), |
553 | /*defaultOperands=*/ValueRange(), |
554 | /*caseValues=*/caseValues, |
555 | /*caseDestinations=*/caseDest, |
556 | /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), |
557 | /*branchWeights=*/ArrayRef<int32_t>()); |
558 | |
559 | return success(); |
560 | } |
561 | }; |
562 | } // namespace |
563 | |
564 | //===----------------------------------------------------------------------===// |
565 | // Convert async.runtime.create to the corresponding runtime API call. |
566 | // |
567 | // To allocate storage for the async values we use getelementptr trick: |
568 | // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt |
569 | //===----------------------------------------------------------------------===// |
570 | |
571 | namespace { |
572 | class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> { |
573 | public: |
574 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
575 | |
576 | LogicalResult |
577 | matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, |
578 | ConversionPatternRewriter &rewriter) const override { |
579 | const TypeConverter *converter = getTypeConverter(); |
580 | Type resultType = op->getResultTypes()[0]; |
581 | |
582 | // Tokens creation maps to a simple function call. |
583 | if (isa<TokenType>(resultType)) { |
584 | rewriter.replaceOpWithNewOp<func::CallOp>( |
585 | op, kCreateToken, converter->convertType(resultType)); |
586 | return success(); |
587 | } |
588 | |
589 | // To create a value we need to compute the storage requirement. |
590 | if (auto value = dyn_cast<ValueType>(resultType)) { |
591 | // Returns the size requirements for the async value storage. |
592 | auto sizeOf = [&](ValueType valueType) -> Value { |
593 | auto loc = op->getLoc(); |
594 | auto i64 = rewriter.getI64Type(); |
595 | |
596 | auto storedType = converter->convertType(valueType.getValueType()); |
597 | auto storagePtrType = |
598 | AsyncAPI::opaquePointerType(rewriter.getContext()); |
599 | |
600 | // %Size = getelementptr %T* null, int 1 |
601 | // %SizeI = ptrtoint %T* %Size to i64 |
602 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); |
603 | auto gep = |
604 | rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, |
605 | nullPtr, ArrayRef<LLVM::GEPArg>{1}); |
606 | return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); |
607 | }; |
608 | |
609 | rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, |
610 | sizeOf(value)); |
611 | |
612 | return success(); |
613 | } |
614 | |
615 | return rewriter.notifyMatchFailure(op, "unsupported async type" ); |
616 | } |
617 | }; |
618 | } // namespace |
619 | |
620 | //===----------------------------------------------------------------------===// |
621 | // Convert async.runtime.create_group to the corresponding runtime API call. |
622 | //===----------------------------------------------------------------------===// |
623 | |
624 | namespace { |
625 | class RuntimeCreateGroupOpLowering |
626 | : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> { |
627 | public: |
628 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
629 | |
630 | LogicalResult |
631 | matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, |
632 | ConversionPatternRewriter &rewriter) const override { |
633 | const TypeConverter *converter = getTypeConverter(); |
634 | Type resultType = op.getResult().getType(); |
635 | |
636 | rewriter.replaceOpWithNewOp<func::CallOp>( |
637 | op, kCreateGroup, converter->convertType(resultType), |
638 | adaptor.getOperands()); |
639 | return success(); |
640 | } |
641 | }; |
642 | } // namespace |
643 | |
644 | //===----------------------------------------------------------------------===// |
645 | // Convert async.runtime.set_available to the corresponding runtime API call. |
646 | //===----------------------------------------------------------------------===// |
647 | |
648 | namespace { |
649 | class RuntimeSetAvailableOpLowering |
650 | : public OpConversionPattern<RuntimeSetAvailableOp> { |
651 | public: |
652 | using OpConversionPattern::OpConversionPattern; |
653 | |
654 | LogicalResult |
655 | matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, |
656 | ConversionPatternRewriter &rewriter) const override { |
657 | StringRef apiFuncName = |
658 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
659 | .Case<TokenType>([](Type) { return kEmplaceToken; }) |
660 | .Case<ValueType>([](Type) { return kEmplaceValue; }); |
661 | |
662 | rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), |
663 | adaptor.getOperands()); |
664 | |
665 | return success(); |
666 | } |
667 | }; |
668 | } // namespace |
669 | |
670 | //===----------------------------------------------------------------------===// |
671 | // Convert async.runtime.set_error to the corresponding runtime API call. |
672 | //===----------------------------------------------------------------------===// |
673 | |
674 | namespace { |
675 | class RuntimeSetErrorOpLowering |
676 | : public OpConversionPattern<RuntimeSetErrorOp> { |
677 | public: |
678 | using OpConversionPattern::OpConversionPattern; |
679 | |
680 | LogicalResult |
681 | matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, |
682 | ConversionPatternRewriter &rewriter) const override { |
683 | StringRef apiFuncName = |
684 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
685 | .Case<TokenType>([](Type) { return kSetTokenError; }) |
686 | .Case<ValueType>([](Type) { return kSetValueError; }); |
687 | |
688 | rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), |
689 | adaptor.getOperands()); |
690 | |
691 | return success(); |
692 | } |
693 | }; |
694 | } // namespace |
695 | |
696 | //===----------------------------------------------------------------------===// |
697 | // Convert async.runtime.is_error to the corresponding runtime API call. |
698 | //===----------------------------------------------------------------------===// |
699 | |
700 | namespace { |
701 | class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { |
702 | public: |
703 | using OpConversionPattern::OpConversionPattern; |
704 | |
705 | LogicalResult |
706 | matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, |
707 | ConversionPatternRewriter &rewriter) const override { |
708 | StringRef apiFuncName = |
709 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
710 | .Case<TokenType>([](Type) { return kIsTokenError; }) |
711 | .Case<GroupType>([](Type) { return kIsGroupError; }) |
712 | .Case<ValueType>([](Type) { return kIsValueError; }); |
713 | |
714 | rewriter.replaceOpWithNewOp<func::CallOp>( |
715 | op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); |
716 | return success(); |
717 | } |
718 | }; |
719 | } // namespace |
720 | |
721 | //===----------------------------------------------------------------------===// |
722 | // Convert async.runtime.await to the corresponding runtime API call. |
723 | //===----------------------------------------------------------------------===// |
724 | |
725 | namespace { |
726 | class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { |
727 | public: |
728 | using OpConversionPattern::OpConversionPattern; |
729 | |
730 | LogicalResult |
731 | matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, |
732 | ConversionPatternRewriter &rewriter) const override { |
733 | StringRef apiFuncName = |
734 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
735 | .Case<TokenType>([](Type) { return kAwaitToken; }) |
736 | .Case<ValueType>([](Type) { return kAwaitValue; }) |
737 | .Case<GroupType>([](Type) { return kAwaitGroup; }); |
738 | |
739 | rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), |
740 | adaptor.getOperands()); |
741 | rewriter.eraseOp(op: op); |
742 | |
743 | return success(); |
744 | } |
745 | }; |
746 | } // namespace |
747 | |
748 | //===----------------------------------------------------------------------===// |
749 | // Convert async.runtime.await_and_resume to the corresponding runtime API call. |
750 | //===----------------------------------------------------------------------===// |
751 | |
752 | namespace { |
753 | class RuntimeAwaitAndResumeOpLowering |
754 | : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> { |
755 | public: |
756 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
757 | |
758 | LogicalResult |
759 | matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, |
760 | ConversionPatternRewriter &rewriter) const override { |
761 | StringRef apiFuncName = |
762 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
763 | .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) |
764 | .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) |
765 | .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); |
766 | |
767 | Value operand = adaptor.getOperand(); |
768 | Value handle = adaptor.getHandle(); |
769 | |
770 | // A pointer to coroutine resume intrinsic wrapper. |
771 | addResumeFunction(op->getParentOfType<ModuleOp>()); |
772 | auto resumePtr = rewriter.create<LLVM::AddressOfOp>( |
773 | op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), |
774 | kResume); |
775 | |
776 | rewriter.create<func::CallOp>( |
777 | op->getLoc(), apiFuncName, TypeRange(), |
778 | ValueRange({operand, handle, resumePtr.getRes()})); |
779 | rewriter.eraseOp(op: op); |
780 | |
781 | return success(); |
782 | } |
783 | }; |
784 | } // namespace |
785 | |
786 | //===----------------------------------------------------------------------===// |
787 | // Convert async.runtime.resume to the corresponding runtime API call. |
788 | //===----------------------------------------------------------------------===// |
789 | |
790 | namespace { |
791 | class RuntimeResumeOpLowering |
792 | : public AsyncOpConversionPattern<RuntimeResumeOp> { |
793 | public: |
794 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
795 | |
796 | LogicalResult |
797 | matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, |
798 | ConversionPatternRewriter &rewriter) const override { |
799 | // A pointer to coroutine resume intrinsic wrapper. |
800 | addResumeFunction(op->getParentOfType<ModuleOp>()); |
801 | auto resumePtr = rewriter.create<LLVM::AddressOfOp>( |
802 | op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), |
803 | kResume); |
804 | |
805 | // Call async runtime API to execute a coroutine in the managed thread. |
806 | auto coroHdl = adaptor.getHandle(); |
807 | rewriter.replaceOpWithNewOp<func::CallOp>( |
808 | op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); |
809 | |
810 | return success(); |
811 | } |
812 | }; |
813 | } // namespace |
814 | |
815 | //===----------------------------------------------------------------------===// |
816 | // Convert async.runtime.store to the corresponding runtime API call. |
817 | //===----------------------------------------------------------------------===// |
818 | |
819 | namespace { |
820 | class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> { |
821 | public: |
822 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
823 | |
824 | LogicalResult |
825 | matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, |
826 | ConversionPatternRewriter &rewriter) const override { |
827 | Location loc = op->getLoc(); |
828 | |
829 | // Get a pointer to the async value storage from the runtime. |
830 | auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); |
831 | auto storage = adaptor.getStorage(); |
832 | auto storagePtr = rewriter.create<func::CallOp>( |
833 | loc, kGetValueStorage, TypeRange(ptrType), storage); |
834 | |
835 | // Cast from i8* to the LLVM pointer type. |
836 | auto valueType = op.getValue().getType(); |
837 | auto llvmValueType = getTypeConverter()->convertType(valueType); |
838 | if (!llvmValueType) |
839 | return rewriter.notifyMatchFailure( |
840 | op, "failed to convert stored value type to LLVM type" ); |
841 | |
842 | Value castedStoragePtr = storagePtr.getResult(0); |
843 | // Store the yielded value into the async value storage. |
844 | auto value = adaptor.getValue(); |
845 | rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); |
846 | |
847 | // Erase the original runtime store operation. |
848 | rewriter.eraseOp(op: op); |
849 | |
850 | return success(); |
851 | } |
852 | }; |
853 | } // namespace |
854 | |
855 | //===----------------------------------------------------------------------===// |
856 | // Convert async.runtime.load to the corresponding runtime API call. |
857 | //===----------------------------------------------------------------------===// |
858 | |
859 | namespace { |
860 | class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> { |
861 | public: |
862 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
863 | |
864 | LogicalResult |
865 | matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, |
866 | ConversionPatternRewriter &rewriter) const override { |
867 | Location loc = op->getLoc(); |
868 | |
869 | // Get a pointer to the async value storage from the runtime. |
870 | auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); |
871 | auto storage = adaptor.getStorage(); |
872 | auto storagePtr = rewriter.create<func::CallOp>( |
873 | loc, kGetValueStorage, TypeRange(ptrType), storage); |
874 | |
875 | // Cast from i8* to the LLVM pointer type. |
876 | auto valueType = op.getResult().getType(); |
877 | auto llvmValueType = getTypeConverter()->convertType(valueType); |
878 | if (!llvmValueType) |
879 | return rewriter.notifyMatchFailure( |
880 | op, "failed to convert loaded value type to LLVM type" ); |
881 | |
882 | Value castedStoragePtr = storagePtr.getResult(0); |
883 | |
884 | // Load from the casted pointer. |
885 | rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType, |
886 | castedStoragePtr); |
887 | |
888 | return success(); |
889 | } |
890 | }; |
891 | } // namespace |
892 | |
893 | //===----------------------------------------------------------------------===// |
894 | // Convert async.runtime.add_to_group to the corresponding runtime API call. |
895 | //===----------------------------------------------------------------------===// |
896 | |
897 | namespace { |
898 | class RuntimeAddToGroupOpLowering |
899 | : public OpConversionPattern<RuntimeAddToGroupOp> { |
900 | public: |
901 | using OpConversionPattern::OpConversionPattern; |
902 | |
903 | LogicalResult |
904 | matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, |
905 | ConversionPatternRewriter &rewriter) const override { |
906 | // Currently we can only add tokens to the group. |
907 | if (!isa<TokenType>(op.getOperand().getType())) |
908 | return rewriter.notifyMatchFailure(op, "only token type is supported" ); |
909 | |
910 | // Replace with a runtime API function call. |
911 | rewriter.replaceOpWithNewOp<func::CallOp>( |
912 | op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); |
913 | |
914 | return success(); |
915 | } |
916 | }; |
917 | } // namespace |
918 | |
919 | //===----------------------------------------------------------------------===// |
920 | // Convert async.runtime.num_worker_threads to the corresponding runtime API |
921 | // call. |
922 | //===----------------------------------------------------------------------===// |
923 | |
924 | namespace { |
925 | class |
926 | : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { |
927 | public: |
928 | using OpConversionPattern::OpConversionPattern; |
929 | |
930 | LogicalResult |
931 | matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, |
932 | ConversionPatternRewriter &rewriter) const override { |
933 | |
934 | // Replace with a runtime API function call. |
935 | rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads, |
936 | rewriter.getIndexType()); |
937 | |
938 | return success(); |
939 | } |
940 | }; |
941 | } // namespace |
942 | |
943 | //===----------------------------------------------------------------------===// |
944 | // Async reference counting ops lowering (`async.runtime.add_ref` and |
945 | // `async.runtime.drop_ref` to the corresponding API calls). |
946 | //===----------------------------------------------------------------------===// |
947 | |
948 | namespace { |
949 | template <typename RefCountingOp> |
950 | class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { |
951 | public: |
952 | explicit RefCountingOpLowering(const TypeConverter &converter, |
953 | MLIRContext *ctx, StringRef apiFunctionName) |
954 | : OpConversionPattern<RefCountingOp>(converter, ctx), |
955 | apiFunctionName(apiFunctionName) {} |
956 | |
957 | LogicalResult |
958 | matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, |
959 | ConversionPatternRewriter &rewriter) const override { |
960 | auto count = rewriter.create<arith::ConstantOp>( |
961 | op->getLoc(), rewriter.getI64Type(), |
962 | rewriter.getI64IntegerAttr(op.getCount())); |
963 | |
964 | auto operand = adaptor.getOperand(); |
965 | rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, |
966 | ValueRange({operand, count})); |
967 | |
968 | return success(); |
969 | } |
970 | |
971 | private: |
972 | StringRef apiFunctionName; |
973 | }; |
974 | |
975 | class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { |
976 | public: |
977 | explicit RuntimeAddRefOpLowering(const TypeConverter &converter, |
978 | MLIRContext *ctx) |
979 | : RefCountingOpLowering(converter, ctx, kAddRef) {} |
980 | }; |
981 | |
982 | class RuntimeDropRefOpLowering |
983 | : public RefCountingOpLowering<RuntimeDropRefOp> { |
984 | public: |
985 | explicit RuntimeDropRefOpLowering(const TypeConverter &converter, |
986 | MLIRContext *ctx) |
987 | : RefCountingOpLowering(converter, ctx, kDropRef) {} |
988 | }; |
989 | } // namespace |
990 | |
991 | //===----------------------------------------------------------------------===// |
992 | // Convert return operations that return async values from async regions. |
993 | //===----------------------------------------------------------------------===// |
994 | |
995 | namespace { |
996 | class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { |
997 | public: |
998 | using OpConversionPattern::OpConversionPattern; |
999 | |
1000 | LogicalResult |
1001 | matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, |
1002 | ConversionPatternRewriter &rewriter) const override { |
1003 | rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); |
1004 | return success(); |
1005 | } |
1006 | }; |
1007 | } // namespace |
1008 | |
1009 | //===----------------------------------------------------------------------===// |
1010 | |
1011 | namespace { |
1012 | struct ConvertAsyncToLLVMPass |
1013 | : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> { |
1014 | using Base::Base; |
1015 | |
1016 | void runOnOperation() override; |
1017 | }; |
1018 | } // namespace |
1019 | |
1020 | void ConvertAsyncToLLVMPass::runOnOperation() { |
1021 | ModuleOp module = getOperation(); |
1022 | MLIRContext *ctx = module->getContext(); |
1023 | |
1024 | LowerToLLVMOptions options(ctx); |
1025 | |
1026 | // Add declarations for most functions required by the coroutines lowering. |
1027 | // We delay adding the resume function until it's needed because it currently |
1028 | // fails to compile unless '-O0' is specified. |
1029 | addAsyncRuntimeApiDeclarations(module); |
1030 | |
1031 | // Lower async.runtime and async.coro operations to Async Runtime API and |
1032 | // LLVM coroutine intrinsics. |
1033 | |
1034 | // Convert async dialect types and operations to LLVM dialect. |
1035 | AsyncRuntimeTypeConverter converter(options); |
1036 | RewritePatternSet patterns(ctx); |
1037 | |
1038 | // We use conversion to LLVM type to lower async.runtime load and store |
1039 | // operations. |
1040 | LLVMTypeConverter llvmConverter(ctx, options); |
1041 | llvmConverter.addConversion(callback: [&](Type type) { |
1042 | return AsyncRuntimeTypeConverter::convertAsyncTypes(type); |
1043 | }); |
1044 | |
1045 | // Convert async types in function signatures and function calls. |
1046 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
1047 | converter); |
1048 | populateCallOpTypeConversionPattern(patterns, converter); |
1049 | |
1050 | // Convert return operations inside async.execute regions. |
1051 | patterns.add<ReturnOpOpConversion>(arg&: converter, args&: ctx); |
1052 | |
1053 | // Lower async.runtime operations to the async runtime API calls. |
1054 | patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, |
1055 | RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, |
1056 | RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, |
1057 | RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, |
1058 | RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(arg&: converter, |
1059 | args&: ctx); |
1060 | |
1061 | // Lower async.runtime operations that rely on LLVM type converter to convert |
1062 | // from async value payload type to the LLVM type. |
1063 | patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, |
1064 | RuntimeStoreOpLowering, RuntimeLoadOpLowering>(arg&: llvmConverter); |
1065 | |
1066 | // Lower async coroutine operations to LLVM coroutine intrinsics. |
1067 | patterns |
1068 | .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, |
1069 | CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( |
1070 | arg&: converter, args&: ctx); |
1071 | |
1072 | ConversionTarget target(*ctx); |
1073 | target.addLegalOp<arith::ConstantOp, func::ConstantOp, |
1074 | UnrealizedConversionCastOp>(); |
1075 | target.addLegalDialect<LLVM::LLVMDialect>(); |
1076 | |
1077 | // All operations from Async dialect must be lowered to the runtime API and |
1078 | // LLVM intrinsics calls. |
1079 | target.addIllegalDialect<AsyncDialect>(); |
1080 | |
1081 | // Add dynamic legality constraints to apply conversions defined above. |
1082 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1083 | return converter.isSignatureLegal(op.getFunctionType()); |
1084 | }); |
1085 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
1086 | return converter.isLegal(op.getOperandTypes()); |
1087 | }); |
1088 | target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { |
1089 | return converter.isSignatureLegal(op.getCalleeType()); |
1090 | }); |
1091 | |
1092 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
1093 | signalPassFailure(); |
1094 | } |
1095 | |
1096 | //===----------------------------------------------------------------------===// |
1097 | // Patterns for structural type conversions for the Async dialect operations. |
1098 | //===----------------------------------------------------------------------===// |
1099 | |
1100 | namespace { |
1101 | class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { |
1102 | public: |
1103 | using OpConversionPattern::OpConversionPattern; |
1104 | LogicalResult |
1105 | matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, |
1106 | ConversionPatternRewriter &rewriter) const override { |
1107 | ExecuteOp newOp = |
1108 | cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); |
1109 | rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), |
1110 | newOp.getRegion().end()); |
1111 | |
1112 | // Set operands and update block argument and result types. |
1113 | newOp->setOperands(adaptor.getOperands()); |
1114 | if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) |
1115 | return failure(); |
1116 | for (auto result : newOp.getResults()) |
1117 | result.setType(typeConverter->convertType(result.getType())); |
1118 | |
1119 | rewriter.replaceOp(op, newOp.getResults()); |
1120 | return success(); |
1121 | } |
1122 | }; |
1123 | |
1124 | // Dummy pattern to trigger the appropriate type conversion / materialization. |
1125 | class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { |
1126 | public: |
1127 | using OpConversionPattern::OpConversionPattern; |
1128 | LogicalResult |
1129 | matchAndRewrite(AwaitOp op, OpAdaptor adaptor, |
1130 | ConversionPatternRewriter &rewriter) const override { |
1131 | rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); |
1132 | return success(); |
1133 | } |
1134 | }; |
1135 | |
1136 | // Dummy pattern to trigger the appropriate type conversion / materialization. |
1137 | class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { |
1138 | public: |
1139 | using OpConversionPattern::OpConversionPattern; |
1140 | LogicalResult |
1141 | matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, |
1142 | ConversionPatternRewriter &rewriter) const override { |
1143 | rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); |
1144 | return success(); |
1145 | } |
1146 | }; |
1147 | } // namespace |
1148 | |
1149 | void mlir::populateAsyncStructuralTypeConversionsAndLegality( |
1150 | TypeConverter &typeConverter, RewritePatternSet &patterns, |
1151 | ConversionTarget &target) { |
1152 | typeConverter.addConversion([&](TokenType type) { return type; }); |
1153 | typeConverter.addConversion(callback: [&](ValueType type) { |
1154 | Type converted = typeConverter.convertType(type.getValueType()); |
1155 | return converted ? ValueType::get(converted) : converted; |
1156 | }); |
1157 | |
1158 | patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( |
1159 | arg&: typeConverter, args: patterns.getContext()); |
1160 | |
1161 | target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( |
1162 | [&](Operation *op) { return typeConverter.isLegal(op); }); |
1163 | } |
1164 | |