| 1 | //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements automatic reference counting for Async runtime |
| 10 | // operations and types. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Async/Passes.h" |
| 15 | |
| 16 | #include "mlir/Analysis/Liveness.h" |
| 17 | #include "mlir/Dialect/Async/IR/Async.h" |
| 18 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 20 | #include "llvm/ADT/SmallSet.h" |
| 21 | |
| 22 | namespace mlir { |
| 23 | #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGPASS |
| 24 | #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTINGPASS |
| 25 | #include "mlir/Dialect/Async/Passes.h.inc" |
| 26 | } // namespace mlir |
| 27 | |
| 28 | #define DEBUG_TYPE "async-runtime-ref-counting" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::async; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Utility functions shared by reference counting passes. |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | // Drop the reference count immediately if the value has no uses. |
| 38 | static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { |
| 39 | if (!value.getUses().empty()) |
| 40 | return failure(); |
| 41 | |
| 42 | OpBuilder b(value.getContext()); |
| 43 | |
| 44 | // Set insertion point after the operation producing a value, or at the |
| 45 | // beginning of the block if the value defined by the block argument. |
| 46 | if (Operation *op = value.getDefiningOp()) |
| 47 | b.setInsertionPointAfter(op); |
| 48 | else |
| 49 | b.setInsertionPointToStart(value.getParentBlock()); |
| 50 | |
| 51 | b.create<RuntimeDropRefOp>(location: value.getLoc(), args&: value, args: b.getI64IntegerAttr(value: 1)); |
| 52 | return success(); |
| 53 | } |
| 54 | |
| 55 | // Calls `addRefCounting` for every reference counted value defined by the |
| 56 | // operation `op` (block arguments and values defined in nested regions). |
| 57 | static LogicalResult walkReferenceCountedValues( |
| 58 | Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { |
| 59 | // Check that we do not have high level async operations in the IR because |
| 60 | // otherwise reference counting will produce incorrect results after high |
| 61 | // level async operations will be lowered to `async.runtime` |
| 62 | WalkResult checkNoAsyncWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
| 63 | if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(Val: op)) |
| 64 | return WalkResult::advance(); |
| 65 | |
| 66 | return op->emitError() |
| 67 | << "async operations must be lowered to async runtime operations" ; |
| 68 | }); |
| 69 | |
| 70 | if (checkNoAsyncWalk.wasInterrupted()) |
| 71 | return failure(); |
| 72 | |
| 73 | // Add reference counting to block arguments. |
| 74 | WalkResult blockWalk = op->walk(callback: [&](Block *block) -> WalkResult { |
| 75 | for (BlockArgument arg : block->getArguments()) |
| 76 | if (isRefCounted(type: arg.getType())) |
| 77 | if (failed(Result: addRefCounting(arg))) |
| 78 | return WalkResult::interrupt(); |
| 79 | |
| 80 | return WalkResult::advance(); |
| 81 | }); |
| 82 | |
| 83 | if (blockWalk.wasInterrupted()) |
| 84 | return failure(); |
| 85 | |
| 86 | // Add reference counting to operation results. |
| 87 | WalkResult opWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
| 88 | for (unsigned i = 0; i < op->getNumResults(); ++i) |
| 89 | if (isRefCounted(type: op->getResultTypes()[i])) |
| 90 | if (failed(Result: addRefCounting(op->getResult(idx: i)))) |
| 91 | return WalkResult::interrupt(); |
| 92 | |
| 93 | return WalkResult::advance(); |
| 94 | }); |
| 95 | |
| 96 | if (opWalk.wasInterrupted()) |
| 97 | return failure(); |
| 98 | |
| 99 | return success(); |
| 100 | } |
| 101 | |
| 102 | //===----------------------------------------------------------------------===// |
| 103 | // Automatic reference counting based on the liveness analysis. |
| 104 | //===----------------------------------------------------------------------===// |
| 105 | |
| 106 | namespace { |
| 107 | |
| 108 | class AsyncRuntimeRefCountingPass |
| 109 | : public impl::AsyncRuntimeRefCountingPassBase< |
| 110 | AsyncRuntimeRefCountingPass> { |
| 111 | public: |
| 112 | AsyncRuntimeRefCountingPass() = default; |
| 113 | void runOnOperation() override; |
| 114 | |
| 115 | private: |
| 116 | /// Adds an automatic reference counting to the `value`. |
| 117 | /// |
| 118 | /// All values (token, group or value) are semantically created with a |
| 119 | /// reference count of +1 and it is the responsibility of the async value user |
| 120 | /// to place the `add_ref` and `drop_ref` operations to ensure that the value |
| 121 | /// is destroyed after the last use. |
| 122 | /// |
| 123 | /// The function returns failure if it can't deduce the locations where |
| 124 | /// to place the reference counting operations. |
| 125 | /// |
| 126 | /// Async values "semantically created" when: |
| 127 | /// 1. Operation returns async result (e.g. `async.runtime.create`) |
| 128 | /// 2. Async value passed in as a block argument (or function argument, |
| 129 | /// because function arguments are just entry block arguments) |
| 130 | /// |
| 131 | /// Passing async value as a function argument (or block argument) does not |
| 132 | /// really mean that a new async value is created, it only means that the |
| 133 | /// caller of a function transfered ownership of `+1` reference to the callee. |
| 134 | /// It is convenient to think that from the callee perspective async value was |
| 135 | /// "created" with `+1` reference by the block argument. |
| 136 | /// |
| 137 | /// Automatic reference counting algorithm outline: |
| 138 | /// |
| 139 | /// #1 Insert `drop_ref` operations after last use of the `value`. |
| 140 | /// #2 Insert `add_ref` operations before functions calls with reference |
| 141 | /// counted `value` operand (newly created `+1` reference will be |
| 142 | /// transferred to the callee). |
| 143 | /// #3 Verify that divergent control flow does not lead to leaked reference |
| 144 | /// counted objects. |
| 145 | /// |
| 146 | /// Async runtime reference counting optimization pass will optimize away |
| 147 | /// some of the redundant `add_ref` and `drop_ref` operations inserted by this |
| 148 | /// strategy (see `async-runtime-ref-counting-opt`). |
| 149 | LogicalResult addAutomaticRefCounting(Value value); |
| 150 | |
| 151 | /// (#1) Adds the `drop_ref` operation after the last use of the `value` |
| 152 | /// relying on the liveness analysis. |
| 153 | /// |
| 154 | /// If the `value` is in the block `liveIn` set and it is not in the block |
| 155 | /// `liveOut` set, it means that it "dies" in the block. We find the last |
| 156 | /// use of the value in such block and: |
| 157 | /// |
| 158 | /// 1. If the last user is a `ReturnLike` operation we do nothing, because |
| 159 | /// it forwards the ownership to the caller. |
| 160 | /// 2. Otherwise we add a `drop_ref` operation immediately after the last |
| 161 | /// use. |
| 162 | LogicalResult addDropRefAfterLastUse(Value value); |
| 163 | |
| 164 | /// (#2) Adds the `add_ref` operation before the function call taking `value` |
| 165 | /// operand to ensure that the value passed to the function entry block |
| 166 | /// has a `+1` reference count. |
| 167 | LogicalResult addAddRefBeforeFunctionCall(Value value); |
| 168 | |
| 169 | /// (#3) Adds the `drop_ref` operation to account for successor blocks with |
| 170 | /// divergent `liveIn` property: `value` is not in the `liveIn` set of all |
| 171 | /// successor blocks. |
| 172 | /// |
| 173 | /// Example: |
| 174 | /// |
| 175 | /// ^entry: |
| 176 | /// %token = async.runtime.create : !async.token |
| 177 | /// cf.cond_br %cond, ^bb1, ^bb2 |
| 178 | /// ^bb1: |
| 179 | /// async.runtime.await %token |
| 180 | /// async.runtime.drop_ref %token |
| 181 | /// cf.br ^bb2 |
| 182 | /// ^bb2: |
| 183 | /// return |
| 184 | /// |
| 185 | /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have |
| 186 | /// to branch into a special "reference counting block" from the ^entry that |
| 187 | /// will have a `drop_ref` operation, and then branch into the ^bb2. |
| 188 | /// |
| 189 | /// After transformation: |
| 190 | /// |
| 191 | /// ^entry: |
| 192 | /// %token = async.runtime.create : !async.token |
| 193 | /// cf.cond_br %cond, ^bb1, ^reference_counting |
| 194 | /// ^bb1: |
| 195 | /// async.runtime.await %token |
| 196 | /// async.runtime.drop_ref %token |
| 197 | /// cf.br ^bb2 |
| 198 | /// ^reference_counting: |
| 199 | /// async.runtime.drop_ref %token |
| 200 | /// cf.br ^bb2 |
| 201 | /// ^bb2: |
| 202 | /// return |
| 203 | /// |
| 204 | /// An exception to this rule are blocks with `async.coro.suspend` terminator, |
| 205 | /// because in Async to LLVM lowering it is guaranteed that the control flow |
| 206 | /// will jump into the resume block, and then follow into the cleanup and |
| 207 | /// suspend blocks. |
| 208 | /// |
| 209 | /// Example: |
| 210 | /// |
| 211 | /// ^entry(%value: !async.value<f32>): |
| 212 | /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> |
| 213 | /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup |
| 214 | /// ^resume: |
| 215 | /// %0 = async.runtime.load %value |
| 216 | /// cf.br ^cleanup |
| 217 | /// ^cleanup: |
| 218 | /// ... |
| 219 | /// ^suspend: |
| 220 | /// ... |
| 221 | /// |
| 222 | /// Although cleanup and suspend blocks do not have the `value` in the |
| 223 | /// `liveIn` set, it is guaranteed that execution will eventually continue in |
| 224 | /// the resume block (we never explicitly destroy coroutines). |
| 225 | LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); |
| 226 | }; |
| 227 | |
| 228 | } // namespace |
| 229 | |
| 230 | LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { |
| 231 | OpBuilder builder(value.getContext()); |
| 232 | Location loc = value.getLoc(); |
| 233 | |
| 234 | // Use liveness analysis to find the placement of `drop_ref`operation. |
| 235 | auto &liveness = getAnalysis<Liveness>(); |
| 236 | |
| 237 | // We analyse only the blocks of the region that defines the `value`, and do |
| 238 | // not check nested blocks attached to operations. |
| 239 | // |
| 240 | // By analyzing only the `definingRegion` CFG we potentially loose an |
| 241 | // opportunity to drop the reference count earlier and can extend the lifetime |
| 242 | // of reference counted value longer then it is really required. |
| 243 | // |
| 244 | // We also assume that all nested regions finish their execution before the |
| 245 | // completion of the owner operation. The only exception to this rule is |
| 246 | // `async.execute` operation, and we verify that they are lowered to the |
| 247 | // `async.runtime` operations before adding automatic reference counting. |
| 248 | Region *definingRegion = value.getParentRegion(); |
| 249 | |
| 250 | // Last users of the `value` inside all blocks where the value dies. |
| 251 | llvm::SmallSet<Operation *, 4> lastUsers; |
| 252 | |
| 253 | // Find blocks in the `definingRegion` that have users of the `value` (if |
| 254 | // there are multiple users in the block, which one will be selected is |
| 255 | // undefined). User operation might be not the actual user of the value, but |
| 256 | // the operation in the block that has a "real user" in one of the attached |
| 257 | // regions. |
| 258 | llvm::DenseMap<Block *, Operation *> usersInTheBlocks; |
| 259 | |
| 260 | for (Operation *user : value.getUsers()) { |
| 261 | Block *userBlock = user->getBlock(); |
| 262 | Block *ancestor = definingRegion->findAncestorBlockInRegion(block&: *userBlock); |
| 263 | usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(op&: *user); |
| 264 | assert(ancestor && "ancestor block must be not null" ); |
| 265 | assert(usersInTheBlocks[ancestor] && "ancestor op must be not null" ); |
| 266 | } |
| 267 | |
| 268 | // Find blocks where the `value` dies: the value is in `liveIn` set and not |
| 269 | // in the `liveOut` set. We place `drop_ref` immediately after the last use |
| 270 | // of the `value` in such regions (after handling few special cases). |
| 271 | // |
| 272 | // We do not traverse all the blocks in the `definingRegion`, because the |
| 273 | // `value` can be in the live in set only if it has users in the block, or it |
| 274 | // is defined in the block. |
| 275 | // |
| 276 | // Values with zero users (only definition) handled explicitly above. |
| 277 | for (auto &blockAndUser : usersInTheBlocks) { |
| 278 | Block *block = blockAndUser.getFirst(); |
| 279 | Operation *userInTheBlock = blockAndUser.getSecond(); |
| 280 | |
| 281 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); |
| 282 | |
| 283 | // Value must be in the live input set or defined in the block. |
| 284 | assert(blockLiveness->isLiveIn(value) || |
| 285 | blockLiveness->getBlock() == value.getParentBlock()); |
| 286 | |
| 287 | // If value is in the live out set, it means it doesn't "die" in the block. |
| 288 | if (blockLiveness->isLiveOut(value)) |
| 289 | continue; |
| 290 | |
| 291 | // At this point we proved that `value` dies in the `block`. Find the last |
| 292 | // use of the `value` inside the `block`, this is where it "dies". |
| 293 | Operation *lastUser = blockLiveness->getEndOperation(value, startOperation: userInTheBlock); |
| 294 | assert(lastUsers.count(lastUser) == 0 && "last users must be unique" ); |
| 295 | lastUsers.insert(Ptr: lastUser); |
| 296 | } |
| 297 | |
| 298 | // Process all the last users of the `value` inside each block where the value |
| 299 | // dies. |
| 300 | for (Operation *lastUser : lastUsers) { |
| 301 | // Return like operations forward reference count. |
| 302 | if (lastUser->hasTrait<OpTrait::ReturnLike>()) |
| 303 | continue; |
| 304 | |
| 305 | // We can't currently handle other types of terminators. |
| 306 | if (lastUser->hasTrait<OpTrait::IsTerminator>()) |
| 307 | return lastUser->emitError() << "async reference counting can't handle " |
| 308 | "terminators that are not ReturnLike" ; |
| 309 | |
| 310 | // Add a drop_ref immediately after the last user. |
| 311 | builder.setInsertionPointAfter(lastUser); |
| 312 | builder.create<RuntimeDropRefOp>(location: loc, args&: value, args: builder.getI64IntegerAttr(value: 1)); |
| 313 | } |
| 314 | |
| 315 | return success(); |
| 316 | } |
| 317 | |
| 318 | LogicalResult |
| 319 | AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { |
| 320 | OpBuilder builder(value.getContext()); |
| 321 | Location loc = value.getLoc(); |
| 322 | |
| 323 | for (Operation *user : value.getUsers()) { |
| 324 | if (!isa<func::CallOp>(Val: user)) |
| 325 | continue; |
| 326 | |
| 327 | // Add a reference before the function call to pass the value at `+1` |
| 328 | // reference to the function entry block. |
| 329 | builder.setInsertionPoint(user); |
| 330 | builder.create<RuntimeAddRefOp>(location: loc, args&: value, args: builder.getI64IntegerAttr(value: 1)); |
| 331 | } |
| 332 | |
| 333 | return success(); |
| 334 | } |
| 335 | |
| 336 | LogicalResult |
| 337 | AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( |
| 338 | Value value) { |
| 339 | using BlockSet = llvm::SmallPtrSet<Block *, 4>; |
| 340 | |
| 341 | OpBuilder builder(value.getContext()); |
| 342 | |
| 343 | // If a block has successors with different `liveIn` property of the `value`, |
| 344 | // record block successors that do not thave the `value` in the `liveIn` set. |
| 345 | llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; |
| 346 | |
| 347 | // Use liveness analysis to find the placement of `drop_ref`operation. |
| 348 | auto &liveness = getAnalysis<Liveness>(); |
| 349 | |
| 350 | // Because we only add `drop_ref` operations to the region that defines the |
| 351 | // `value` we can only process CFG for the same region. |
| 352 | Region *definingRegion = value.getParentRegion(); |
| 353 | |
| 354 | // Collect blocks with successors with mismatching `liveIn` sets. |
| 355 | for (Block &block : definingRegion->getBlocks()) { |
| 356 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block: &block); |
| 357 | |
| 358 | // Skip the block if value is not in the `liveOut` set. |
| 359 | if (!blockLiveness || !blockLiveness->isLiveOut(value)) |
| 360 | continue; |
| 361 | |
| 362 | BlockSet liveInSuccessors; // `value` is in `liveIn` set |
| 363 | BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set |
| 364 | |
| 365 | // Collect successors that do not have `value` in the `liveIn` set. |
| 366 | for (Block *successor : block.getSuccessors()) { |
| 367 | const LivenessBlockInfo *succLiveness = liveness.getLiveness(block: successor); |
| 368 | if (succLiveness && succLiveness->isLiveIn(value)) |
| 369 | liveInSuccessors.insert(Ptr: successor); |
| 370 | else |
| 371 | noLiveInSuccessors.insert(Ptr: successor); |
| 372 | } |
| 373 | |
| 374 | // Block has successors with different `liveIn` property of the `value`. |
| 375 | if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) |
| 376 | divergentLivenessBlocks.try_emplace(Key: &block, Args&: noLiveInSuccessors); |
| 377 | } |
| 378 | |
| 379 | // Try to insert `dropRef` operations to handle blocks with divergent liveness |
| 380 | // in successors blocks. |
| 381 | for (auto kv : divergentLivenessBlocks) { |
| 382 | Block *block = kv.getFirst(); |
| 383 | BlockSet &successors = kv.getSecond(); |
| 384 | |
| 385 | // Coroutine suspension is a special case terminator for wich we do not |
| 386 | // need to create additional reference counting (see details above). |
| 387 | Operation *terminator = block->getTerminator(); |
| 388 | if (isa<CoroSuspendOp>(Val: terminator)) |
| 389 | continue; |
| 390 | |
| 391 | // We only support successor blocks with empty block argument list. |
| 392 | auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; |
| 393 | if (llvm::any_of(Range&: successors, P: hasArgs)) |
| 394 | return terminator->emitOpError() |
| 395 | << "successor have different `liveIn` property of the reference " |
| 396 | "counted value" ; |
| 397 | |
| 398 | // Make sure that `dropRef` operation is called when branched into the |
| 399 | // successor block without `value` in the `liveIn` set. |
| 400 | for (Block *successor : successors) { |
| 401 | // If successor has a unique predecessor, it is safe to create `dropRef` |
| 402 | // operations directly in the successor block. |
| 403 | // |
| 404 | // Otherwise we need to create a special block for reference counting |
| 405 | // operations, and branch from it to the original successor block. |
| 406 | Block *refCountingBlock = nullptr; |
| 407 | |
| 408 | if (successor->getUniquePredecessor() == block) { |
| 409 | refCountingBlock = successor; |
| 410 | } else { |
| 411 | refCountingBlock = &successor->getParent()->emplaceBlock(); |
| 412 | refCountingBlock->moveBefore(block: successor); |
| 413 | OpBuilder builder = OpBuilder::atBlockEnd(block: refCountingBlock); |
| 414 | builder.create<cf::BranchOp>(location: value.getLoc(), args&: successor); |
| 415 | } |
| 416 | |
| 417 | OpBuilder builder = OpBuilder::atBlockBegin(block: refCountingBlock); |
| 418 | builder.create<RuntimeDropRefOp>(location: value.getLoc(), args&: value, |
| 419 | args: builder.getI64IntegerAttr(value: 1)); |
| 420 | |
| 421 | // No need to update the terminator operation. |
| 422 | if (successor == refCountingBlock) |
| 423 | continue; |
| 424 | |
| 425 | // Update terminator `successor` block to `refCountingBlock`. |
| 426 | for (const auto &pair : llvm::enumerate(First: terminator->getSuccessors())) |
| 427 | if (pair.value() == successor) |
| 428 | terminator->setSuccessor(block: refCountingBlock, index: pair.index()); |
| 429 | } |
| 430 | } |
| 431 | |
| 432 | return success(); |
| 433 | } |
| 434 | |
| 435 | LogicalResult |
| 436 | AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { |
| 437 | // Short-circuit reference counting for values without uses. |
| 438 | if (succeeded(Result: dropRefIfNoUses(value))) |
| 439 | return success(); |
| 440 | |
| 441 | // Add `drop_ref` operations based on the liveness analysis. |
| 442 | if (failed(Result: addDropRefAfterLastUse(value))) |
| 443 | return failure(); |
| 444 | |
| 445 | // Add `add_ref` operations before function calls. |
| 446 | if (failed(Result: addAddRefBeforeFunctionCall(value))) |
| 447 | return failure(); |
| 448 | |
| 449 | // Add `drop_ref` operations to successors with divergent `value` liveness. |
| 450 | if (failed(Result: addDropRefInDivergentLivenessSuccessor(value))) |
| 451 | return failure(); |
| 452 | |
| 453 | return success(); |
| 454 | } |
| 455 | |
| 456 | void AsyncRuntimeRefCountingPass::runOnOperation() { |
| 457 | auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; |
| 458 | if (failed(Result: walkReferenceCountedValues(op: getOperation(), addRefCounting: functor))) |
| 459 | signalPassFailure(); |
| 460 | } |
| 461 | |
| 462 | //===----------------------------------------------------------------------===// |
| 463 | // Reference counting based on the user defined policy. |
| 464 | //===----------------------------------------------------------------------===// |
| 465 | |
| 466 | namespace { |
| 467 | |
| 468 | class AsyncRuntimePolicyBasedRefCountingPass |
| 469 | : public impl::AsyncRuntimePolicyBasedRefCountingPassBase< |
| 470 | AsyncRuntimePolicyBasedRefCountingPass> { |
| 471 | public: |
| 472 | AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } |
| 473 | |
| 474 | void runOnOperation() override; |
| 475 | |
| 476 | private: |
| 477 | // Adds a reference counting operations for all uses of the `value` according |
| 478 | // to the reference counting policy. |
| 479 | LogicalResult addRefCounting(Value value); |
| 480 | |
| 481 | void initializeDefaultPolicy(); |
| 482 | |
| 483 | llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; |
| 484 | }; |
| 485 | |
| 486 | } // namespace |
| 487 | |
| 488 | LogicalResult |
| 489 | AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { |
| 490 | // Short-circuit reference counting for values without uses. |
| 491 | if (succeeded(Result: dropRefIfNoUses(value))) |
| 492 | return success(); |
| 493 | |
| 494 | OpBuilder b(value.getContext()); |
| 495 | |
| 496 | // Consult the user defined policy for every value use. |
| 497 | for (OpOperand &operand : value.getUses()) { |
| 498 | Location loc = operand.getOwner()->getLoc(); |
| 499 | |
| 500 | for (auto &func : policy) { |
| 501 | FailureOr<int> refCount = func(operand); |
| 502 | if (failed(Result: refCount)) |
| 503 | return failure(); |
| 504 | |
| 505 | int cnt = *refCount; |
| 506 | |
| 507 | // Create `add_ref` operation before the operand owner. |
| 508 | if (cnt > 0) { |
| 509 | b.setInsertionPoint(operand.getOwner()); |
| 510 | b.create<RuntimeAddRefOp>(location: loc, args&: value, args: b.getI64IntegerAttr(value: cnt)); |
| 511 | } |
| 512 | |
| 513 | // Create `drop_ref` operation after the operand owner. |
| 514 | if (cnt < 0) { |
| 515 | b.setInsertionPointAfter(operand.getOwner()); |
| 516 | b.create<RuntimeDropRefOp>(location: loc, args&: value, args: b.getI64IntegerAttr(value: -cnt)); |
| 517 | } |
| 518 | } |
| 519 | } |
| 520 | |
| 521 | return success(); |
| 522 | } |
| 523 | |
| 524 | void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { |
| 525 | policy.push_back(Elt: [](OpOperand &operand) -> FailureOr<int> { |
| 526 | Operation *op = operand.getOwner(); |
| 527 | Type type = operand.get().getType(); |
| 528 | |
| 529 | bool isToken = isa<TokenType>(Val: type); |
| 530 | bool isGroup = isa<GroupType>(Val: type); |
| 531 | bool isValue = isa<ValueType>(Val: type); |
| 532 | |
| 533 | // Drop reference after async token or group error check (coro await). |
| 534 | if (isa<RuntimeIsErrorOp>(Val: op)) |
| 535 | return (isToken || isGroup) ? -1 : 0; |
| 536 | |
| 537 | // Drop reference after async value load. |
| 538 | if (isa<RuntimeLoadOp>(Val: op)) |
| 539 | return isValue ? -1 : 0; |
| 540 | |
| 541 | // Drop reference after async token added to the group. |
| 542 | if (isa<RuntimeAddToGroupOp>(Val: op)) |
| 543 | return isToken ? -1 : 0; |
| 544 | |
| 545 | return 0; |
| 546 | }); |
| 547 | } |
| 548 | |
| 549 | void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { |
| 550 | auto functor = [&](Value value) { return addRefCounting(value); }; |
| 551 | if (failed(Result: walkReferenceCountedValues(op: getOperation(), addRefCounting: functor))) |
| 552 | signalPassFailure(); |
| 553 | } |
| 554 | |