1 | //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Async/IR/Async.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
20 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
22 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
23 | #include "mlir/IR/TypeUtilities.h" |
24 | #include "mlir/Pass/Pass.h" |
25 | #include "mlir/Transforms/DialectConversion.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS |
30 | #include "mlir/Conversion/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | #define DEBUG_TYPE "convert-async-to-llvm" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::async; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Async Runtime C API declaration. |
40 | //===----------------------------------------------------------------------===// |
41 | |
// Symbol names of the Async Runtime C API functions. These are the names under
// which the functions are declared in the module and referenced by the
// lowering patterns in this file.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Value storage access.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Switching tokens and values to the available state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Error state: setters and queries.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execution, optionally after awaiting.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Runtime configuration query.
// NOTE(review): the symbol is spelled "Runtim" (no trailing 'e'). It
// presumably has to match the name exported by the runtime library, so
// confirm against the runtime before "fixing" the spelling.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70 | |
71 | namespace { |
72 | /// Async Runtime API function types. |
73 | /// |
74 | /// Because we can't create API function signature for type parametrized |
75 | /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After |
76 | /// lowering all async data types become opaque pointers at runtime. |
77 | struct AsyncAPI { |
78 | // All async types are lowered to opaque LLVM pointers at runtime. |
79 | static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { |
80 | return LLVM::LLVMPointerType::get(ctx); |
81 | } |
82 | |
83 | static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { |
84 | return LLVM::LLVMTokenType::get(ctx); |
85 | } |
86 | |
87 | static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { |
88 | auto ref = opaquePointerType(ctx); |
89 | auto count = IntegerType::get(ctx, 64); |
90 | return FunctionType::get(ctx, {ref, count}, {}); |
91 | } |
92 | |
93 | static FunctionType createTokenFunctionType(MLIRContext *ctx) { |
94 | return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); |
95 | } |
96 | |
97 | static FunctionType createValueFunctionType(MLIRContext *ctx) { |
98 | auto i64 = IntegerType::get(ctx, 64); |
99 | auto value = opaquePointerType(ctx); |
100 | return FunctionType::get(ctx, {i64}, {value}); |
101 | } |
102 | |
103 | static FunctionType createGroupFunctionType(MLIRContext *ctx) { |
104 | auto i64 = IntegerType::get(ctx, 64); |
105 | return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); |
106 | } |
107 | |
108 | static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { |
109 | auto ptrType = opaquePointerType(ctx); |
110 | return FunctionType::get(ctx, {ptrType}, {ptrType}); |
111 | } |
112 | |
113 | static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { |
114 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
115 | } |
116 | |
117 | static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { |
118 | auto value = opaquePointerType(ctx); |
119 | return FunctionType::get(ctx, {value}, {}); |
120 | } |
121 | |
122 | static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { |
123 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
124 | } |
125 | |
126 | static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { |
127 | auto value = opaquePointerType(ctx); |
128 | return FunctionType::get(ctx, {value}, {}); |
129 | } |
130 | |
131 | static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { |
132 | auto i1 = IntegerType::get(ctx, 1); |
133 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); |
134 | } |
135 | |
136 | static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { |
137 | auto value = opaquePointerType(ctx); |
138 | auto i1 = IntegerType::get(ctx, 1); |
139 | return FunctionType::get(ctx, {value}, {i1}); |
140 | } |
141 | |
142 | static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { |
143 | auto i1 = IntegerType::get(ctx, 1); |
144 | return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); |
145 | } |
146 | |
147 | static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { |
148 | return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); |
149 | } |
150 | |
151 | static FunctionType awaitValueFunctionType(MLIRContext *ctx) { |
152 | auto value = opaquePointerType(ctx); |
153 | return FunctionType::get(ctx, {value}, {}); |
154 | } |
155 | |
156 | static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { |
157 | return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); |
158 | } |
159 | |
160 | static FunctionType executeFunctionType(MLIRContext *ctx) { |
161 | auto ptrType = opaquePointerType(ctx); |
162 | return FunctionType::get(ctx, {ptrType, ptrType}, {}); |
163 | } |
164 | |
165 | static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { |
166 | auto i64 = IntegerType::get(ctx, 64); |
167 | return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, |
168 | {i64}); |
169 | } |
170 | |
171 | static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { |
172 | auto ptrType = opaquePointerType(ctx); |
173 | return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {}); |
174 | } |
175 | |
176 | static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { |
177 | auto ptrType = opaquePointerType(ctx); |
178 | return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {}); |
179 | } |
180 | |
181 | static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { |
182 | auto ptrType = opaquePointerType(ctx); |
183 | return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {}); |
184 | } |
185 | |
186 | static FunctionType getNumWorkerThreads(MLIRContext *ctx) { |
187 | return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); |
188 | } |
189 | |
190 | // Auxiliary coroutine resume intrinsic wrapper. |
191 | static Type resumeFunctionType(MLIRContext *ctx) { |
192 | auto voidTy = LLVM::LLVMVoidType::get(ctx); |
193 | auto ptrType = opaquePointerType(ctx); |
194 | return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false); |
195 | } |
196 | }; |
197 | } // namespace |
198 | |
199 | /// Adds Async Runtime C API declarations to the module. |
200 | static void addAsyncRuntimeApiDeclarations(ModuleOp module) { |
201 | auto builder = |
202 | ImplicitLocOpBuilder::atBlockEnd(loc: module.getLoc(), block: module.getBody()); |
203 | |
204 | auto addFuncDecl = [&](StringRef name, FunctionType type) { |
205 | if (module.lookupSymbol(name)) |
206 | return; |
207 | builder.create<func::FuncOp>(name, type).setPrivate(); |
208 | }; |
209 | |
210 | MLIRContext *ctx = module.getContext(); |
211 | addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); |
212 | addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); |
213 | addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); |
214 | addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); |
215 | addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); |
216 | addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); |
217 | addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); |
218 | addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); |
219 | addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); |
220 | addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); |
221 | addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); |
222 | addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); |
223 | addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); |
224 | addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); |
225 | addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); |
226 | addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); |
227 | addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); |
228 | addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); |
229 | addFuncDecl(kAwaitTokenAndExecute, |
230 | AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); |
231 | addFuncDecl(kAwaitValueAndExecute, |
232 | AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); |
233 | addFuncDecl(kAwaitAllAndExecute, |
234 | AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); |
235 | addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); |
236 | } |
237 | |
238 | //===----------------------------------------------------------------------===// |
239 | // Coroutine resume function wrapper. |
240 | //===----------------------------------------------------------------------===// |
241 | |
242 | static constexpr const char *kResume = "__resume"; |
243 | |
244 | /// A function that takes a coroutine handle and calls a `llvm.coro.resume` |
245 | /// intrinsics. We need this function to be able to pass it to the async |
246 | /// runtime execute API. |
247 | static void addResumeFunction(ModuleOp module) { |
248 | if (module.lookupSymbol(kResume)) |
249 | return; |
250 | |
251 | MLIRContext *ctx = module.getContext(); |
252 | auto loc = module.getLoc(); |
253 | auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: module.getBody()); |
254 | |
255 | auto voidTy = LLVM::LLVMVoidType::get(ctx); |
256 | Type ptrType = AsyncAPI::opaquePointerType(ctx); |
257 | |
258 | auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( |
259 | kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); |
260 | resumeOp.setPrivate(); |
261 | |
262 | auto *block = resumeOp.addEntryBlock(moduleBuilder); |
263 | auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc: loc, block: block); |
264 | |
265 | blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); |
266 | blockBuilder.create<LLVM::ReturnOp>(ValueRange()); |
267 | } |
268 | |
269 | //===----------------------------------------------------------------------===// |
270 | // Convert Async dialect types to LLVM types. |
271 | //===----------------------------------------------------------------------===// |
272 | |
273 | namespace { |
274 | /// AsyncRuntimeTypeConverter only converts types from the Async dialect to |
275 | /// their runtime type (opaque pointers) and does not convert any other types. |
276 | class AsyncRuntimeTypeConverter : public TypeConverter { |
277 | public: |
278 | AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) { |
279 | addConversion(callback: [](Type type) { return type; }); |
280 | addConversion(callback: [](Type type) { return convertAsyncTypes(type); }); |
281 | |
282 | // Use UnrealizedConversionCast as the bridge so that we don't need to pull |
283 | // in patterns for other dialects. |
284 | auto addUnrealizedCast = [](OpBuilder &builder, Type type, |
285 | ValueRange inputs, Location loc) -> Value { |
286 | auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); |
287 | return cast.getResult(0); |
288 | }; |
289 | |
290 | addSourceMaterialization(callback&: addUnrealizedCast); |
291 | addTargetMaterialization(callback&: addUnrealizedCast); |
292 | } |
293 | |
294 | static std::optional<Type> convertAsyncTypes(Type type) { |
295 | if (isa<TokenType, GroupType, ValueType>(type)) |
296 | return AsyncAPI::opaquePointerType(type.getContext()); |
297 | |
298 | if (isa<CoroIdType, CoroStateType>(type)) |
299 | return AsyncAPI::tokenType(ctx: type.getContext()); |
300 | if (isa<CoroHandleType>(type)) |
301 | return AsyncAPI::opaquePointerType(type.getContext()); |
302 | |
303 | return std::nullopt; |
304 | } |
305 | }; |
306 | |
307 | /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter |
308 | /// as type converter. Allows access to it via the 'getTypeConverter' |
309 | /// convenience method. |
310 | template <typename SourceOp> |
311 | class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> { |
312 | |
313 | using Base = OpConversionPattern<SourceOp>; |
314 | |
315 | public: |
316 | AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, |
317 | MLIRContext *context) |
318 | : Base(typeConverter, context) {} |
319 | |
320 | /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. |
321 | const AsyncRuntimeTypeConverter *getTypeConverter() const { |
322 | return static_cast<const AsyncRuntimeTypeConverter *>( |
323 | Base::getTypeConverter()); |
324 | } |
325 | }; |
326 | |
327 | } // namespace |
328 | |
329 | //===----------------------------------------------------------------------===// |
330 | // Convert async.coro.id to @llvm.coro.id intrinsic. |
331 | //===----------------------------------------------------------------------===// |
332 | |
333 | namespace { |
334 | class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> { |
335 | public: |
336 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
337 | |
338 | LogicalResult |
339 | matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, |
340 | ConversionPatternRewriter &rewriter) const override { |
341 | auto token = AsyncAPI::tokenType(ctx: op->getContext()); |
342 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
343 | auto loc = op->getLoc(); |
344 | |
345 | // Constants for initializing coroutine frame. |
346 | auto constZero = |
347 | rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
348 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); |
349 | |
350 | // Get coroutine id: @llvm.coro.id. |
351 | rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( |
352 | op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); |
353 | |
354 | return success(); |
355 | } |
356 | }; |
357 | } // namespace |
358 | |
359 | //===----------------------------------------------------------------------===// |
360 | // Convert async.coro.begin to @llvm.coro.begin intrinsic. |
361 | //===----------------------------------------------------------------------===// |
362 | |
363 | namespace { |
364 | class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> { |
365 | public: |
366 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
367 | |
368 | LogicalResult |
369 | matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, |
370 | ConversionPatternRewriter &rewriter) const override { |
371 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
372 | auto loc = op->getLoc(); |
373 | |
374 | // Get coroutine frame size: @llvm.coro.size.i64. |
375 | Value coroSize = |
376 | rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); |
377 | // Get coroutine frame alignment: @llvm.coro.align.i64. |
378 | Value coroAlign = |
379 | rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); |
380 | |
381 | // Round up the size to be multiple of the alignment. Since aligned_alloc |
382 | // requires the size parameter be an integral multiple of the alignment |
383 | // parameter. |
384 | auto makeConstant = [&](uint64_t c) { |
385 | return rewriter.create<LLVM::ConstantOp>(op->getLoc(), |
386 | rewriter.getI64Type(), c); |
387 | }; |
388 | coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); |
389 | coroSize = |
390 | rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); |
391 | Value negCoroAlign = |
392 | rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); |
393 | coroSize = |
394 | rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); |
395 | |
396 | // Allocate memory for the coroutine frame. |
397 | auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( |
398 | b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>(), indexType: rewriter.getI64Type()); |
399 | if (failed(allocFuncOp)) |
400 | return failure(); |
401 | auto coroAlloc = rewriter.create<LLVM::CallOp>( |
402 | loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); |
403 | |
404 | // Begin a coroutine: @llvm.coro.begin. |
405 | auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); |
406 | rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( |
407 | op, ptrType, ValueRange({coroId, coroAlloc.getResult()})); |
408 | |
409 | return success(); |
410 | } |
411 | }; |
412 | } // namespace |
413 | |
414 | //===----------------------------------------------------------------------===// |
415 | // Convert async.coro.free to @llvm.coro.free intrinsic. |
416 | //===----------------------------------------------------------------------===// |
417 | |
418 | namespace { |
419 | class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> { |
420 | public: |
421 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
422 | |
423 | LogicalResult |
424 | matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, |
425 | ConversionPatternRewriter &rewriter) const override { |
426 | auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); |
427 | auto loc = op->getLoc(); |
428 | |
429 | // Get a pointer to the coroutine frame memory: @llvm.coro.free. |
430 | auto coroMem = |
431 | rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); |
432 | |
433 | // Free the memory. |
434 | auto freeFuncOp = |
435 | LLVM::lookupOrCreateFreeFn(b&: rewriter, moduleOp: op->getParentOfType<ModuleOp>()); |
436 | if (failed(freeFuncOp)) |
437 | return failure(); |
438 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(), |
439 | ValueRange(coroMem.getResult())); |
440 | |
441 | return success(); |
442 | } |
443 | }; |
444 | } // namespace |
445 | |
446 | //===----------------------------------------------------------------------===// |
447 | // Convert async.coro.end to @llvm.coro.end intrinsic. |
448 | //===----------------------------------------------------------------------===// |
449 | |
450 | namespace { |
451 | class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { |
452 | public: |
453 | using OpConversionPattern::OpConversionPattern; |
454 | |
455 | LogicalResult |
456 | matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, |
457 | ConversionPatternRewriter &rewriter) const override { |
458 | // We are not in the block that is part of the unwind sequence. |
459 | auto constFalse = rewriter.create<LLVM::ConstantOp>( |
460 | op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); |
461 | auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); |
462 | |
463 | // Mark the end of a coroutine: @llvm.coro.end. |
464 | auto coroHdl = adaptor.getHandle(); |
465 | rewriter.create<LLVM::CoroEndOp>( |
466 | op->getLoc(), rewriter.getI1Type(), |
467 | ValueRange({coroHdl, constFalse, noneToken})); |
468 | rewriter.eraseOp(op: op); |
469 | |
470 | return success(); |
471 | } |
472 | }; |
473 | } // namespace |
474 | |
475 | //===----------------------------------------------------------------------===// |
476 | // Convert async.coro.save to @llvm.coro.save intrinsic. |
477 | //===----------------------------------------------------------------------===// |
478 | |
479 | namespace { |
480 | class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { |
481 | public: |
482 | using OpConversionPattern::OpConversionPattern; |
483 | |
484 | LogicalResult |
485 | matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, |
486 | ConversionPatternRewriter &rewriter) const override { |
487 | // Save the coroutine state: @llvm.coro.save |
488 | rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( |
489 | op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); |
490 | |
491 | return success(); |
492 | } |
493 | }; |
494 | } // namespace |
495 | |
496 | //===----------------------------------------------------------------------===// |
497 | // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. |
498 | //===----------------------------------------------------------------------===// |
499 | |
500 | namespace { |
501 | |
502 | /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and |
503 | /// branch to the appropriate block based on the return code. |
504 | /// |
505 | /// Before: |
506 | /// |
507 | /// ^suspended: |
508 | /// "opBefore"(...) |
509 | /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup |
510 | /// ^resume: |
511 | /// "op"(...) |
512 | /// ^cleanup: ... |
513 | /// ^suspend: ... |
514 | /// |
515 | /// After: |
516 | /// |
517 | /// ^suspended: |
518 | /// "opBefore"(...) |
519 | /// %suspend = llmv.intr.coro.suspend ... |
520 | /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] |
521 | /// ^resume: |
522 | /// "op"(...) |
523 | /// ^cleanup: ... |
524 | /// ^suspend: ... |
525 | /// |
526 | class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { |
527 | public: |
528 | using OpConversionPattern::OpConversionPattern; |
529 | |
530 | LogicalResult |
531 | matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, |
532 | ConversionPatternRewriter &rewriter) const override { |
533 | auto i8 = rewriter.getIntegerType(8); |
534 | auto i32 = rewriter.getI32Type(); |
535 | auto loc = op->getLoc(); |
536 | |
537 | // This is not a final suspension point. |
538 | auto constFalse = rewriter.create<LLVM::ConstantOp>( |
539 | loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); |
540 | |
541 | // Suspend a coroutine: @llvm.coro.suspend |
542 | auto coroState = adaptor.getState(); |
543 | auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( |
544 | loc, i8, ValueRange({coroState, constFalse})); |
545 | |
546 | // Cast return code to i32. |
547 | |
548 | // After a suspension point decide if we should branch into resume, cleanup |
549 | // or suspend block of the coroutine (see @llvm.coro.suspend return code |
550 | // documentation). |
551 | llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; |
552 | llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), |
553 | op.getCleanupDest()}; |
554 | rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
555 | op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), |
556 | /*defaultDestination=*/op.getSuspendDest(), |
557 | /*defaultOperands=*/ValueRange(), |
558 | /*caseValues=*/caseValues, |
559 | /*caseDestinations=*/caseDest, |
560 | /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), |
561 | /*branchWeights=*/ArrayRef<int32_t>()); |
562 | |
563 | return success(); |
564 | } |
565 | }; |
566 | } // namespace |
567 | |
568 | //===----------------------------------------------------------------------===// |
569 | // Convert async.runtime.create to the corresponding runtime API call. |
570 | // |
571 | // To allocate storage for the async values we use getelementptr trick: |
572 | // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt |
573 | //===----------------------------------------------------------------------===// |
574 | |
575 | namespace { |
576 | class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> { |
577 | public: |
578 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
579 | |
580 | LogicalResult |
581 | matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, |
582 | ConversionPatternRewriter &rewriter) const override { |
583 | const TypeConverter *converter = getTypeConverter(); |
584 | Type resultType = op->getResultTypes()[0]; |
585 | |
586 | // Tokens creation maps to a simple function call. |
587 | if (isa<TokenType>(resultType)) { |
588 | rewriter.replaceOpWithNewOp<func::CallOp>( |
589 | op, kCreateToken, converter->convertType(resultType)); |
590 | return success(); |
591 | } |
592 | |
593 | // To create a value we need to compute the storage requirement. |
594 | if (auto value = dyn_cast<ValueType>(resultType)) { |
595 | // Returns the size requirements for the async value storage. |
596 | auto sizeOf = [&](ValueType valueType) -> Value { |
597 | auto loc = op->getLoc(); |
598 | auto i64 = rewriter.getI64Type(); |
599 | |
600 | auto storedType = converter->convertType(valueType.getValueType()); |
601 | auto storagePtrType = |
602 | AsyncAPI::opaquePointerType(rewriter.getContext()); |
603 | |
604 | // %Size = getelementptr %T* null, int 1 |
605 | // %SizeI = ptrtoint %T* %Size to i64 |
606 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); |
607 | auto gep = |
608 | rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, |
609 | nullPtr, ArrayRef<LLVM::GEPArg>{1}); |
610 | return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); |
611 | }; |
612 | |
613 | rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, |
614 | sizeOf(value)); |
615 | |
616 | return success(); |
617 | } |
618 | |
619 | return rewriter.notifyMatchFailure(op, "unsupported async type"); |
620 | } |
621 | }; |
622 | } // namespace |
623 | |
624 | //===----------------------------------------------------------------------===// |
625 | // Convert async.runtime.create_group to the corresponding runtime API call. |
626 | //===----------------------------------------------------------------------===// |
627 | |
628 | namespace { |
629 | class RuntimeCreateGroupOpLowering |
630 | : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> { |
631 | public: |
632 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
633 | |
634 | LogicalResult |
635 | matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, |
636 | ConversionPatternRewriter &rewriter) const override { |
637 | const TypeConverter *converter = getTypeConverter(); |
638 | Type resultType = op.getResult().getType(); |
639 | |
640 | rewriter.replaceOpWithNewOp<func::CallOp>( |
641 | op, kCreateGroup, converter->convertType(resultType), |
642 | adaptor.getOperands()); |
643 | return success(); |
644 | } |
645 | }; |
646 | } // namespace |
647 | |
648 | //===----------------------------------------------------------------------===// |
649 | // Convert async.runtime.set_available to the corresponding runtime API call. |
650 | //===----------------------------------------------------------------------===// |
651 | |
652 | namespace { |
653 | class RuntimeSetAvailableOpLowering |
654 | : public OpConversionPattern<RuntimeSetAvailableOp> { |
655 | public: |
656 | using OpConversionPattern::OpConversionPattern; |
657 | |
658 | LogicalResult |
659 | matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, |
660 | ConversionPatternRewriter &rewriter) const override { |
661 | StringRef apiFuncName = |
662 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
663 | .Case<TokenType>([](Type) { return kEmplaceToken; }) |
664 | .Case<ValueType>([](Type) { return kEmplaceValue; }); |
665 | |
666 | rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), |
667 | adaptor.getOperands()); |
668 | |
669 | return success(); |
670 | } |
671 | }; |
672 | } // namespace |
673 | |
674 | //===----------------------------------------------------------------------===// |
675 | // Convert async.runtime.set_error to the corresponding runtime API call. |
676 | //===----------------------------------------------------------------------===// |
677 | |
678 | namespace { |
679 | class RuntimeSetErrorOpLowering |
680 | : public OpConversionPattern<RuntimeSetErrorOp> { |
681 | public: |
682 | using OpConversionPattern::OpConversionPattern; |
683 | |
684 | LogicalResult |
685 | matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, |
686 | ConversionPatternRewriter &rewriter) const override { |
687 | StringRef apiFuncName = |
688 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
689 | .Case<TokenType>([](Type) { return kSetTokenError; }) |
690 | .Case<ValueType>([](Type) { return kSetValueError; }); |
691 | |
692 | rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), |
693 | adaptor.getOperands()); |
694 | |
695 | return success(); |
696 | } |
697 | }; |
698 | } // namespace |
699 | |
700 | //===----------------------------------------------------------------------===// |
701 | // Convert async.runtime.is_error to the corresponding runtime API call. |
702 | //===----------------------------------------------------------------------===// |
703 | |
704 | namespace { |
705 | class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { |
706 | public: |
707 | using OpConversionPattern::OpConversionPattern; |
708 | |
709 | LogicalResult |
710 | matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, |
711 | ConversionPatternRewriter &rewriter) const override { |
712 | StringRef apiFuncName = |
713 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
714 | .Case<TokenType>([](Type) { return kIsTokenError; }) |
715 | .Case<GroupType>([](Type) { return kIsGroupError; }) |
716 | .Case<ValueType>([](Type) { return kIsValueError; }); |
717 | |
718 | rewriter.replaceOpWithNewOp<func::CallOp>( |
719 | op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); |
720 | return success(); |
721 | } |
722 | }; |
723 | } // namespace |
724 | |
725 | //===----------------------------------------------------------------------===// |
726 | // Convert async.runtime.await to the corresponding runtime API call. |
727 | //===----------------------------------------------------------------------===// |
728 | |
729 | namespace { |
730 | class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { |
731 | public: |
732 | using OpConversionPattern::OpConversionPattern; |
733 | |
734 | LogicalResult |
735 | matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, |
736 | ConversionPatternRewriter &rewriter) const override { |
737 | StringRef apiFuncName = |
738 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
739 | .Case<TokenType>([](Type) { return kAwaitToken; }) |
740 | .Case<ValueType>([](Type) { return kAwaitValue; }) |
741 | .Case<GroupType>([](Type) { return kAwaitGroup; }); |
742 | |
743 | rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), |
744 | adaptor.getOperands()); |
745 | rewriter.eraseOp(op: op); |
746 | |
747 | return success(); |
748 | } |
749 | }; |
750 | } // namespace |
751 | |
752 | //===----------------------------------------------------------------------===// |
753 | // Convert async.runtime.await_and_resume to the corresponding runtime API call. |
754 | //===----------------------------------------------------------------------===// |
755 | |
756 | namespace { |
757 | class RuntimeAwaitAndResumeOpLowering |
758 | : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> { |
759 | public: |
760 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
761 | |
762 | LogicalResult |
763 | matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, |
764 | ConversionPatternRewriter &rewriter) const override { |
765 | StringRef apiFuncName = |
766 | TypeSwitch<Type, StringRef>(op.getOperand().getType()) |
767 | .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) |
768 | .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) |
769 | .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); |
770 | |
771 | Value operand = adaptor.getOperand(); |
772 | Value handle = adaptor.getHandle(); |
773 | |
774 | // A pointer to coroutine resume intrinsic wrapper. |
775 | addResumeFunction(op->getParentOfType<ModuleOp>()); |
776 | auto resumePtr = rewriter.create<LLVM::AddressOfOp>( |
777 | op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), |
778 | kResume); |
779 | |
780 | rewriter.create<func::CallOp>( |
781 | op->getLoc(), apiFuncName, TypeRange(), |
782 | ValueRange({operand, handle, resumePtr.getRes()})); |
783 | rewriter.eraseOp(op: op); |
784 | |
785 | return success(); |
786 | } |
787 | }; |
788 | } // namespace |
789 | |
790 | //===----------------------------------------------------------------------===// |
791 | // Convert async.runtime.resume to the corresponding runtime API call. |
792 | //===----------------------------------------------------------------------===// |
793 | |
794 | namespace { |
795 | class RuntimeResumeOpLowering |
796 | : public AsyncOpConversionPattern<RuntimeResumeOp> { |
797 | public: |
798 | using AsyncOpConversionPattern::AsyncOpConversionPattern; |
799 | |
800 | LogicalResult |
801 | matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, |
802 | ConversionPatternRewriter &rewriter) const override { |
803 | // A pointer to coroutine resume intrinsic wrapper. |
804 | addResumeFunction(op->getParentOfType<ModuleOp>()); |
805 | auto resumePtr = rewriter.create<LLVM::AddressOfOp>( |
806 | op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), |
807 | kResume); |
808 | |
809 | // Call async runtime API to execute a coroutine in the managed thread. |
810 | auto coroHdl = adaptor.getHandle(); |
811 | rewriter.replaceOpWithNewOp<func::CallOp>( |
812 | op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); |
813 | |
814 | return success(); |
815 | } |
816 | }; |
817 | } // namespace |
818 | |
819 | //===----------------------------------------------------------------------===// |
820 | // Convert async.runtime.store to the corresponding runtime API call. |
821 | //===----------------------------------------------------------------------===// |
822 | |
823 | namespace { |
824 | class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> { |
825 | public: |
826 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
827 | |
828 | LogicalResult |
829 | matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, |
830 | ConversionPatternRewriter &rewriter) const override { |
831 | Location loc = op->getLoc(); |
832 | |
833 | // Get a pointer to the async value storage from the runtime. |
834 | auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); |
835 | auto storage = adaptor.getStorage(); |
836 | auto storagePtr = rewriter.create<func::CallOp>( |
837 | loc, kGetValueStorage, TypeRange(ptrType), storage); |
838 | |
839 | // Cast from i8* to the LLVM pointer type. |
840 | auto valueType = op.getValue().getType(); |
841 | auto llvmValueType = getTypeConverter()->convertType(valueType); |
842 | if (!llvmValueType) |
843 | return rewriter.notifyMatchFailure( |
844 | op, "failed to convert stored value type to LLVM type"); |
845 | |
846 | Value castedStoragePtr = storagePtr.getResult(0); |
847 | // Store the yielded value into the async value storage. |
848 | auto value = adaptor.getValue(); |
849 | rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); |
850 | |
851 | // Erase the original runtime store operation. |
852 | rewriter.eraseOp(op: op); |
853 | |
854 | return success(); |
855 | } |
856 | }; |
857 | } // namespace |
858 | |
859 | //===----------------------------------------------------------------------===// |
860 | // Convert async.runtime.load to the corresponding runtime API call. |
861 | //===----------------------------------------------------------------------===// |
862 | |
863 | namespace { |
864 | class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> { |
865 | public: |
866 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
867 | |
868 | LogicalResult |
869 | matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, |
870 | ConversionPatternRewriter &rewriter) const override { |
871 | Location loc = op->getLoc(); |
872 | |
873 | // Get a pointer to the async value storage from the runtime. |
874 | auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); |
875 | auto storage = adaptor.getStorage(); |
876 | auto storagePtr = rewriter.create<func::CallOp>( |
877 | loc, kGetValueStorage, TypeRange(ptrType), storage); |
878 | |
879 | // Cast from i8* to the LLVM pointer type. |
880 | auto valueType = op.getResult().getType(); |
881 | auto llvmValueType = getTypeConverter()->convertType(valueType); |
882 | if (!llvmValueType) |
883 | return rewriter.notifyMatchFailure( |
884 | op, "failed to convert loaded value type to LLVM type"); |
885 | |
886 | Value castedStoragePtr = storagePtr.getResult(0); |
887 | |
888 | // Load from the casted pointer. |
889 | rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType, |
890 | castedStoragePtr); |
891 | |
892 | return success(); |
893 | } |
894 | }; |
895 | } // namespace |
896 | |
897 | //===----------------------------------------------------------------------===// |
898 | // Convert async.runtime.add_to_group to the corresponding runtime API call. |
899 | //===----------------------------------------------------------------------===// |
900 | |
901 | namespace { |
902 | class RuntimeAddToGroupOpLowering |
903 | : public OpConversionPattern<RuntimeAddToGroupOp> { |
904 | public: |
905 | using OpConversionPattern::OpConversionPattern; |
906 | |
907 | LogicalResult |
908 | matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, |
909 | ConversionPatternRewriter &rewriter) const override { |
910 | // Currently we can only add tokens to the group. |
911 | if (!isa<TokenType>(op.getOperand().getType())) |
912 | return rewriter.notifyMatchFailure(op, "only token type is supported"); |
913 | |
914 | // Replace with a runtime API function call. |
915 | rewriter.replaceOpWithNewOp<func::CallOp>( |
916 | op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); |
917 | |
918 | return success(); |
919 | } |
920 | }; |
921 | } // namespace |
922 | |
923 | //===----------------------------------------------------------------------===// |
924 | // Convert async.runtime.num_worker_threads to the corresponding runtime API |
925 | // call. |
926 | //===----------------------------------------------------------------------===// |
927 | |
928 | namespace { |
929 | class RuntimeNumWorkerThreadsOpLowering |
930 | : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { |
931 | public: |
932 | using OpConversionPattern::OpConversionPattern; |
933 | |
934 | LogicalResult |
935 | matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, |
936 | ConversionPatternRewriter &rewriter) const override { |
937 | |
938 | // Replace with a runtime API function call. |
939 | rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads, |
940 | rewriter.getIndexType()); |
941 | |
942 | return success(); |
943 | } |
944 | }; |
945 | } // namespace |
946 | |
947 | //===----------------------------------------------------------------------===// |
948 | // Async reference counting ops lowering (`async.runtime.add_ref` and |
949 | // `async.runtime.drop_ref` to the corresponding API calls). |
950 | //===----------------------------------------------------------------------===// |
951 | |
952 | namespace { |
953 | template <typename RefCountingOp> |
954 | class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { |
955 | public: |
956 | explicit RefCountingOpLowering(const TypeConverter &converter, |
957 | MLIRContext *ctx, StringRef apiFunctionName) |
958 | : OpConversionPattern<RefCountingOp>(converter, ctx), |
959 | apiFunctionName(apiFunctionName) {} |
960 | |
961 | LogicalResult |
962 | matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, |
963 | ConversionPatternRewriter &rewriter) const override { |
964 | auto count = rewriter.create<arith::ConstantOp>( |
965 | op->getLoc(), rewriter.getI64Type(), |
966 | rewriter.getI64IntegerAttr(op.getCount())); |
967 | |
968 | auto operand = adaptor.getOperand(); |
969 | rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, |
970 | ValueRange({operand, count})); |
971 | |
972 | return success(); |
973 | } |
974 | |
975 | private: |
976 | StringRef apiFunctionName; |
977 | }; |
978 | |
979 | class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { |
980 | public: |
981 | explicit RuntimeAddRefOpLowering(const TypeConverter &converter, |
982 | MLIRContext *ctx) |
983 | : RefCountingOpLowering(converter, ctx, kAddRef) {} |
984 | }; |
985 | |
986 | class RuntimeDropRefOpLowering |
987 | : public RefCountingOpLowering<RuntimeDropRefOp> { |
988 | public: |
989 | explicit RuntimeDropRefOpLowering(const TypeConverter &converter, |
990 | MLIRContext *ctx) |
991 | : RefCountingOpLowering(converter, ctx, kDropRef) {} |
992 | }; |
993 | } // namespace |
994 | |
995 | //===----------------------------------------------------------------------===// |
996 | // Convert return operations that return async values from async regions. |
997 | //===----------------------------------------------------------------------===// |
998 | |
999 | namespace { |
1000 | class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { |
1001 | public: |
1002 | using OpConversionPattern::OpConversionPattern; |
1003 | |
1004 | LogicalResult |
1005 | matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, |
1006 | ConversionPatternRewriter &rewriter) const override { |
1007 | rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); |
1008 | return success(); |
1009 | } |
1010 | }; |
1011 | } // namespace |
1012 | |
1013 | //===----------------------------------------------------------------------===// |
1014 | |
1015 | namespace { |
1016 | struct ConvertAsyncToLLVMPass |
1017 | : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> { |
1018 | using Base::Base; |
1019 | |
1020 | void runOnOperation() override; |
1021 | }; |
1022 | } // namespace |
1023 | |
1024 | void ConvertAsyncToLLVMPass::runOnOperation() { |
1025 | ModuleOp module = getOperation(); |
1026 | MLIRContext *ctx = module->getContext(); |
1027 | |
1028 | LowerToLLVMOptions options(ctx); |
1029 | |
1030 | // Add declarations for most functions required by the coroutines lowering. |
1031 | // We delay adding the resume function until it's needed because it currently |
1032 | // fails to compile unless '-O0' is specified. |
1033 | addAsyncRuntimeApiDeclarations(module); |
1034 | |
1035 | // Lower async.runtime and async.coro operations to Async Runtime API and |
1036 | // LLVM coroutine intrinsics. |
1037 | |
1038 | // Convert async dialect types and operations to LLVM dialect. |
1039 | AsyncRuntimeTypeConverter converter(options); |
1040 | RewritePatternSet patterns(ctx); |
1041 | |
1042 | // We use conversion to LLVM type to lower async.runtime load and store |
1043 | // operations. |
1044 | LLVMTypeConverter llvmConverter(ctx, options); |
1045 | llvmConverter.addConversion(callback: [&](Type type) { |
1046 | return AsyncRuntimeTypeConverter::convertAsyncTypes(type); |
1047 | }); |
1048 | |
1049 | // Convert async types in function signatures and function calls. |
1050 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
1051 | converter); |
1052 | populateCallOpTypeConversionPattern(patterns, converter); |
1053 | |
1054 | // Convert return operations inside async.execute regions. |
1055 | patterns.add<ReturnOpOpConversion>(arg&: converter, args&: ctx); |
1056 | |
1057 | // Lower async.runtime operations to the async runtime API calls. |
1058 | patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, |
1059 | RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, |
1060 | RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, |
1061 | RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, |
1062 | RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(arg&: converter, |
1063 | args&: ctx); |
1064 | |
1065 | // Lower async.runtime operations that rely on LLVM type converter to convert |
1066 | // from async value payload type to the LLVM type. |
1067 | patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, |
1068 | RuntimeStoreOpLowering, RuntimeLoadOpLowering>(arg&: llvmConverter); |
1069 | |
1070 | // Lower async coroutine operations to LLVM coroutine intrinsics. |
1071 | patterns |
1072 | .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, |
1073 | CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( |
1074 | arg&: converter, args&: ctx); |
1075 | |
1076 | ConversionTarget target(*ctx); |
1077 | target.addLegalOp<arith::ConstantOp, func::ConstantOp, |
1078 | UnrealizedConversionCastOp>(); |
1079 | target.addLegalDialect<LLVM::LLVMDialect>(); |
1080 | |
1081 | // All operations from Async dialect must be lowered to the runtime API and |
1082 | // LLVM intrinsics calls. |
1083 | target.addIllegalDialect<AsyncDialect>(); |
1084 | |
1085 | // Add dynamic legality constraints to apply conversions defined above. |
1086 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1087 | return converter.isSignatureLegal(op.getFunctionType()); |
1088 | }); |
1089 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
1090 | return converter.isLegal(op.getOperandTypes()); |
1091 | }); |
1092 | target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { |
1093 | return converter.isSignatureLegal(op.getCalleeType()); |
1094 | }); |
1095 | |
1096 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
1097 | signalPassFailure(); |
1098 | } |
1099 | |
1100 | //===----------------------------------------------------------------------===// |
1101 | // Patterns for structural type conversions for the Async dialect operations. |
1102 | //===----------------------------------------------------------------------===// |
1103 | |
1104 | namespace { |
1105 | class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { |
1106 | public: |
1107 | using OpConversionPattern::OpConversionPattern; |
1108 | LogicalResult |
1109 | matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, |
1110 | ConversionPatternRewriter &rewriter) const override { |
1111 | ExecuteOp newOp = |
1112 | cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); |
1113 | rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), |
1114 | newOp.getRegion().end()); |
1115 | |
1116 | // Set operands and update block argument and result types. |
1117 | newOp->setOperands(adaptor.getOperands()); |
1118 | if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) |
1119 | return failure(); |
1120 | for (auto result : newOp.getResults()) |
1121 | result.setType(typeConverter->convertType(result.getType())); |
1122 | |
1123 | rewriter.replaceOp(op, newOp.getResults()); |
1124 | return success(); |
1125 | } |
1126 | }; |
1127 | |
1128 | // Dummy pattern to trigger the appropriate type conversion / materialization. |
1129 | class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { |
1130 | public: |
1131 | using OpConversionPattern::OpConversionPattern; |
1132 | LogicalResult |
1133 | matchAndRewrite(AwaitOp op, OpAdaptor adaptor, |
1134 | ConversionPatternRewriter &rewriter) const override { |
1135 | rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); |
1136 | return success(); |
1137 | } |
1138 | }; |
1139 | |
1140 | // Dummy pattern to trigger the appropriate type conversion / materialization. |
1141 | class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { |
1142 | public: |
1143 | using OpConversionPattern::OpConversionPattern; |
1144 | LogicalResult |
1145 | matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, |
1146 | ConversionPatternRewriter &rewriter) const override { |
1147 | rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); |
1148 | return success(); |
1149 | } |
1150 | }; |
1151 | } // namespace |
1152 | |
1153 | void mlir::populateAsyncStructuralTypeConversionsAndLegality( |
1154 | TypeConverter &typeConverter, RewritePatternSet &patterns, |
1155 | ConversionTarget &target) { |
1156 | typeConverter.addConversion([&](TokenType type) { return type; }); |
1157 | typeConverter.addConversion(callback: [&](ValueType type) { |
1158 | Type converted = typeConverter.convertType(type.getValueType()); |
1159 | return converted ? ValueType::get(converted) : converted; |
1160 | }); |
1161 | |
1162 | patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( |
1163 | arg&: typeConverter, args: patterns.getContext()); |
1164 | |
1165 | target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( |
1166 | [&](Operation *op) { return typeConverter.isLegal(op); }); |
1167 | } |
1168 |
Definitions
- kAddRef
- kDropRef
- kCreateToken
- kCreateValue
- kCreateGroup
- kEmplaceToken
- kEmplaceValue
- kSetTokenError
- kSetValueError
- kIsTokenError
- kIsValueError
- kIsGroupError
- kAwaitToken
- kAwaitValue
- kAwaitGroup
- kExecute
- kGetValueStorage
- kAddTokenToGroup
- kAwaitTokenAndExecute
- kAwaitValueAndExecute
- kAwaitAllAndExecute
- kGetNumWorkerThreads
- AsyncAPI
- opaquePointerType
- tokenType
- addOrDropRefFunctionType
- createTokenFunctionType
- createValueFunctionType
- createGroupFunctionType
- getValueStorageFunctionType
- emplaceTokenFunctionType
- emplaceValueFunctionType
- setTokenErrorFunctionType
- setValueErrorFunctionType
- isTokenErrorFunctionType
- isValueErrorFunctionType
- isGroupErrorFunctionType
- awaitTokenFunctionType
- awaitValueFunctionType
- awaitGroupFunctionType
- executeFunctionType
- addTokenToGroupFunctionType
- awaitTokenAndExecuteFunctionType
- awaitValueAndExecuteFunctionType
- awaitAllAndExecuteFunctionType
- getNumWorkerThreads
- resumeFunctionType
- addAsyncRuntimeApiDeclarations
- kResume
- addResumeFunction
- AsyncRuntimeTypeConverter
- AsyncRuntimeTypeConverter
- convertAsyncTypes
- AsyncOpConversionPattern
- AsyncOpConversionPattern
- getTypeConverter
- CoroIdOpConversion
- matchAndRewrite
- CoroBeginOpConversion
- matchAndRewrite
- CoroFreeOpConversion
- matchAndRewrite
- CoroEndOpConversion
- matchAndRewrite
- CoroSaveOpConversion
- matchAndRewrite
- CoroSuspendOpConversion
- matchAndRewrite
- RuntimeCreateOpLowering
- matchAndRewrite
- RuntimeCreateGroupOpLowering
- matchAndRewrite
- RuntimeSetAvailableOpLowering
- matchAndRewrite
- RuntimeSetErrorOpLowering
- matchAndRewrite
- RuntimeIsErrorOpLowering
- matchAndRewrite
- RuntimeAwaitOpLowering
- matchAndRewrite
- RuntimeAwaitAndResumeOpLowering
- matchAndRewrite
- RuntimeResumeOpLowering
- matchAndRewrite
- RuntimeStoreOpLowering
- matchAndRewrite
- RuntimeLoadOpLowering
- matchAndRewrite
- RuntimeAddToGroupOpLowering
- matchAndRewrite
- RuntimeNumWorkerThreadsOpLowering
- matchAndRewrite
- RefCountingOpLowering
- RefCountingOpLowering
- matchAndRewrite
- RuntimeAddRefOpLowering
- RuntimeAddRefOpLowering
- RuntimeDropRefOpLowering
- RuntimeDropRefOpLowering
- ReturnOpOpConversion
- matchAndRewrite
- ConvertAsyncToLLVMPass
- runOnOperation
- ConvertExecuteOpTypes
- matchAndRewrite
- ConvertAwaitOpTypes
- matchAndRewrite
- ConvertYieldOpTypes
- matchAndRewrite
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more