| 1 | //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements automatic reference counting for Async runtime |
| 10 | // operations and types. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Async/Passes.h" |
| 15 | |
| 16 | #include "mlir/Analysis/Liveness.h" |
| 17 | #include "mlir/Dialect/Async/IR/Async.h" |
| 18 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 20 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 21 | #include "mlir/IR/PatternMatch.h" |
| 22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 23 | #include "llvm/ADT/SmallSet.h" |
| 24 | |
| 25 | namespace mlir { |
| 26 | #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGPASS |
| 27 | #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTINGPASS |
| 28 | #include "mlir/Dialect/Async/Passes.h.inc" |
| 29 | } // namespace mlir |
| 30 | |
| 31 | #define DEBUG_TYPE "async-runtime-ref-counting" |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace mlir::async; |
| 35 | |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | // Utility functions shared by reference counting passes. |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | |
| 40 | // Drop the reference count immediately if the value has no uses. |
| 41 | static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { |
| 42 | if (!value.getUses().empty()) |
| 43 | return failure(); |
| 44 | |
| 45 | OpBuilder b(value.getContext()); |
| 46 | |
| 47 | // Set insertion point after the operation producing a value, or at the |
| 48 | // beginning of the block if the value defined by the block argument. |
| 49 | if (Operation *op = value.getDefiningOp()) |
| 50 | b.setInsertionPointAfter(op); |
| 51 | else |
| 52 | b.setInsertionPointToStart(value.getParentBlock()); |
| 53 | |
| 54 | b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); |
| 55 | return success(); |
| 56 | } |
| 57 | |
| 58 | // Calls `addRefCounting` for every reference counted value defined by the |
| 59 | // operation `op` (block arguments and values defined in nested regions). |
| 60 | static LogicalResult walkReferenceCountedValues( |
| 61 | Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { |
| 62 | // Check that we do not have high level async operations in the IR because |
| 63 | // otherwise reference counting will produce incorrect results after high |
| 64 | // level async operations will be lowered to `async.runtime` |
| 65 | WalkResult checkNoAsyncWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
| 66 | if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) |
| 67 | return WalkResult::advance(); |
| 68 | |
| 69 | return op->emitError() |
| 70 | << "async operations must be lowered to async runtime operations" ; |
| 71 | }); |
| 72 | |
| 73 | if (checkNoAsyncWalk.wasInterrupted()) |
| 74 | return failure(); |
| 75 | |
| 76 | // Add reference counting to block arguments. |
| 77 | WalkResult blockWalk = op->walk(callback: [&](Block *block) -> WalkResult { |
| 78 | for (BlockArgument arg : block->getArguments()) |
| 79 | if (isRefCounted(type: arg.getType())) |
| 80 | if (failed(Result: addRefCounting(arg))) |
| 81 | return WalkResult::interrupt(); |
| 82 | |
| 83 | return WalkResult::advance(); |
| 84 | }); |
| 85 | |
| 86 | if (blockWalk.wasInterrupted()) |
| 87 | return failure(); |
| 88 | |
| 89 | // Add reference counting to operation results. |
| 90 | WalkResult opWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
| 91 | for (unsigned i = 0; i < op->getNumResults(); ++i) |
| 92 | if (isRefCounted(type: op->getResultTypes()[i])) |
| 93 | if (failed(Result: addRefCounting(op->getResult(idx: i)))) |
| 94 | return WalkResult::interrupt(); |
| 95 | |
| 96 | return WalkResult::advance(); |
| 97 | }); |
| 98 | |
| 99 | if (opWalk.wasInterrupted()) |
| 100 | return failure(); |
| 101 | |
| 102 | return success(); |
| 103 | } |
| 104 | |
| 105 | //===----------------------------------------------------------------------===// |
| 106 | // Automatic reference counting based on the liveness analysis. |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | |
| 109 | namespace { |
| 110 | |
| 111 | class AsyncRuntimeRefCountingPass |
| 112 | : public impl::AsyncRuntimeRefCountingPassBase< |
| 113 | AsyncRuntimeRefCountingPass> { |
| 114 | public: |
| 115 | AsyncRuntimeRefCountingPass() = default; |
| 116 | void runOnOperation() override; |
| 117 | |
| 118 | private: |
| 119 | /// Adds an automatic reference counting to the `value`. |
| 120 | /// |
| 121 | /// All values (token, group or value) are semantically created with a |
| 122 | /// reference count of +1 and it is the responsibility of the async value user |
| 123 | /// to place the `add_ref` and `drop_ref` operations to ensure that the value |
| 124 | /// is destroyed after the last use. |
| 125 | /// |
| 126 | /// The function returns failure if it can't deduce the locations where |
| 127 | /// to place the reference counting operations. |
| 128 | /// |
| 129 | /// Async values "semantically created" when: |
| 130 | /// 1. Operation returns async result (e.g. `async.runtime.create`) |
| 131 | /// 2. Async value passed in as a block argument (or function argument, |
| 132 | /// because function arguments are just entry block arguments) |
| 133 | /// |
| 134 | /// Passing async value as a function argument (or block argument) does not |
| 135 | /// really mean that a new async value is created, it only means that the |
| 136 | /// caller of a function transfered ownership of `+1` reference to the callee. |
| 137 | /// It is convenient to think that from the callee perspective async value was |
| 138 | /// "created" with `+1` reference by the block argument. |
| 139 | /// |
| 140 | /// Automatic reference counting algorithm outline: |
| 141 | /// |
| 142 | /// #1 Insert `drop_ref` operations after last use of the `value`. |
| 143 | /// #2 Insert `add_ref` operations before functions calls with reference |
| 144 | /// counted `value` operand (newly created `+1` reference will be |
| 145 | /// transferred to the callee). |
| 146 | /// #3 Verify that divergent control flow does not lead to leaked reference |
| 147 | /// counted objects. |
| 148 | /// |
| 149 | /// Async runtime reference counting optimization pass will optimize away |
| 150 | /// some of the redundant `add_ref` and `drop_ref` operations inserted by this |
| 151 | /// strategy (see `async-runtime-ref-counting-opt`). |
| 152 | LogicalResult addAutomaticRefCounting(Value value); |
| 153 | |
| 154 | /// (#1) Adds the `drop_ref` operation after the last use of the `value` |
| 155 | /// relying on the liveness analysis. |
| 156 | /// |
| 157 | /// If the `value` is in the block `liveIn` set and it is not in the block |
| 158 | /// `liveOut` set, it means that it "dies" in the block. We find the last |
| 159 | /// use of the value in such block and: |
| 160 | /// |
| 161 | /// 1. If the last user is a `ReturnLike` operation we do nothing, because |
| 162 | /// it forwards the ownership to the caller. |
| 163 | /// 2. Otherwise we add a `drop_ref` operation immediately after the last |
| 164 | /// use. |
| 165 | LogicalResult addDropRefAfterLastUse(Value value); |
| 166 | |
| 167 | /// (#2) Adds the `add_ref` operation before the function call taking `value` |
| 168 | /// operand to ensure that the value passed to the function entry block |
| 169 | /// has a `+1` reference count. |
| 170 | LogicalResult addAddRefBeforeFunctionCall(Value value); |
| 171 | |
| 172 | /// (#3) Adds the `drop_ref` operation to account for successor blocks with |
| 173 | /// divergent `liveIn` property: `value` is not in the `liveIn` set of all |
| 174 | /// successor blocks. |
| 175 | /// |
| 176 | /// Example: |
| 177 | /// |
| 178 | /// ^entry: |
| 179 | /// %token = async.runtime.create : !async.token |
| 180 | /// cf.cond_br %cond, ^bb1, ^bb2 |
| 181 | /// ^bb1: |
| 182 | /// async.runtime.await %token |
| 183 | /// async.runtime.drop_ref %token |
| 184 | /// cf.br ^bb2 |
| 185 | /// ^bb2: |
| 186 | /// return |
| 187 | /// |
| 188 | /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have |
| 189 | /// to branch into a special "reference counting block" from the ^entry that |
| 190 | /// will have a `drop_ref` operation, and then branch into the ^bb2. |
| 191 | /// |
| 192 | /// After transformation: |
| 193 | /// |
| 194 | /// ^entry: |
| 195 | /// %token = async.runtime.create : !async.token |
| 196 | /// cf.cond_br %cond, ^bb1, ^reference_counting |
| 197 | /// ^bb1: |
| 198 | /// async.runtime.await %token |
| 199 | /// async.runtime.drop_ref %token |
| 200 | /// cf.br ^bb2 |
| 201 | /// ^reference_counting: |
| 202 | /// async.runtime.drop_ref %token |
| 203 | /// cf.br ^bb2 |
| 204 | /// ^bb2: |
| 205 | /// return |
| 206 | /// |
| 207 | /// An exception to this rule are blocks with `async.coro.suspend` terminator, |
| 208 | /// because in Async to LLVM lowering it is guaranteed that the control flow |
| 209 | /// will jump into the resume block, and then follow into the cleanup and |
| 210 | /// suspend blocks. |
| 211 | /// |
| 212 | /// Example: |
| 213 | /// |
| 214 | /// ^entry(%value: !async.value<f32>): |
| 215 | /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> |
| 216 | /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup |
| 217 | /// ^resume: |
| 218 | /// %0 = async.runtime.load %value |
| 219 | /// cf.br ^cleanup |
| 220 | /// ^cleanup: |
| 221 | /// ... |
| 222 | /// ^suspend: |
| 223 | /// ... |
| 224 | /// |
| 225 | /// Although cleanup and suspend blocks do not have the `value` in the |
| 226 | /// `liveIn` set, it is guaranteed that execution will eventually continue in |
| 227 | /// the resume block (we never explicitly destroy coroutines). |
| 228 | LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); |
| 229 | }; |
| 230 | |
| 231 | } // namespace |
| 232 | |
| 233 | LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { |
| 234 | OpBuilder builder(value.getContext()); |
| 235 | Location loc = value.getLoc(); |
| 236 | |
| 237 | // Use liveness analysis to find the placement of `drop_ref`operation. |
| 238 | auto &liveness = getAnalysis<Liveness>(); |
| 239 | |
| 240 | // We analyse only the blocks of the region that defines the `value`, and do |
| 241 | // not check nested blocks attached to operations. |
| 242 | // |
| 243 | // By analyzing only the `definingRegion` CFG we potentially loose an |
| 244 | // opportunity to drop the reference count earlier and can extend the lifetime |
| 245 | // of reference counted value longer then it is really required. |
| 246 | // |
| 247 | // We also assume that all nested regions finish their execution before the |
| 248 | // completion of the owner operation. The only exception to this rule is |
| 249 | // `async.execute` operation, and we verify that they are lowered to the |
| 250 | // `async.runtime` operations before adding automatic reference counting. |
| 251 | Region *definingRegion = value.getParentRegion(); |
| 252 | |
| 253 | // Last users of the `value` inside all blocks where the value dies. |
| 254 | llvm::SmallSet<Operation *, 4> lastUsers; |
| 255 | |
| 256 | // Find blocks in the `definingRegion` that have users of the `value` (if |
| 257 | // there are multiple users in the block, which one will be selected is |
| 258 | // undefined). User operation might be not the actual user of the value, but |
| 259 | // the operation in the block that has a "real user" in one of the attached |
| 260 | // regions. |
| 261 | llvm::DenseMap<Block *, Operation *> usersInTheBlocks; |
| 262 | |
| 263 | for (Operation *user : value.getUsers()) { |
| 264 | Block *userBlock = user->getBlock(); |
| 265 | Block *ancestor = definingRegion->findAncestorBlockInRegion(block&: *userBlock); |
| 266 | usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(op&: *user); |
| 267 | assert(ancestor && "ancestor block must be not null" ); |
| 268 | assert(usersInTheBlocks[ancestor] && "ancestor op must be not null" ); |
| 269 | } |
| 270 | |
| 271 | // Find blocks where the `value` dies: the value is in `liveIn` set and not |
| 272 | // in the `liveOut` set. We place `drop_ref` immediately after the last use |
| 273 | // of the `value` in such regions (after handling few special cases). |
| 274 | // |
| 275 | // We do not traverse all the blocks in the `definingRegion`, because the |
| 276 | // `value` can be in the live in set only if it has users in the block, or it |
| 277 | // is defined in the block. |
| 278 | // |
| 279 | // Values with zero users (only definition) handled explicitly above. |
| 280 | for (auto &blockAndUser : usersInTheBlocks) { |
| 281 | Block *block = blockAndUser.getFirst(); |
| 282 | Operation *userInTheBlock = blockAndUser.getSecond(); |
| 283 | |
| 284 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); |
| 285 | |
| 286 | // Value must be in the live input set or defined in the block. |
| 287 | assert(blockLiveness->isLiveIn(value) || |
| 288 | blockLiveness->getBlock() == value.getParentBlock()); |
| 289 | |
| 290 | // If value is in the live out set, it means it doesn't "die" in the block. |
| 291 | if (blockLiveness->isLiveOut(value)) |
| 292 | continue; |
| 293 | |
| 294 | // At this point we proved that `value` dies in the `block`. Find the last |
| 295 | // use of the `value` inside the `block`, this is where it "dies". |
| 296 | Operation *lastUser = blockLiveness->getEndOperation(value, startOperation: userInTheBlock); |
| 297 | assert(lastUsers.count(lastUser) == 0 && "last users must be unique" ); |
| 298 | lastUsers.insert(Ptr: lastUser); |
| 299 | } |
| 300 | |
| 301 | // Process all the last users of the `value` inside each block where the value |
| 302 | // dies. |
| 303 | for (Operation *lastUser : lastUsers) { |
| 304 | // Return like operations forward reference count. |
| 305 | if (lastUser->hasTrait<OpTrait::ReturnLike>()) |
| 306 | continue; |
| 307 | |
| 308 | // We can't currently handle other types of terminators. |
| 309 | if (lastUser->hasTrait<OpTrait::IsTerminator>()) |
| 310 | return lastUser->emitError() << "async reference counting can't handle " |
| 311 | "terminators that are not ReturnLike" ; |
| 312 | |
| 313 | // Add a drop_ref immediately after the last user. |
| 314 | builder.setInsertionPointAfter(lastUser); |
| 315 | builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
| 316 | } |
| 317 | |
| 318 | return success(); |
| 319 | } |
| 320 | |
| 321 | LogicalResult |
| 322 | AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { |
| 323 | OpBuilder builder(value.getContext()); |
| 324 | Location loc = value.getLoc(); |
| 325 | |
| 326 | for (Operation *user : value.getUsers()) { |
| 327 | if (!isa<func::CallOp>(user)) |
| 328 | continue; |
| 329 | |
| 330 | // Add a reference before the function call to pass the value at `+1` |
| 331 | // reference to the function entry block. |
| 332 | builder.setInsertionPoint(user); |
| 333 | builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
| 334 | } |
| 335 | |
| 336 | return success(); |
| 337 | } |
| 338 | |
| 339 | LogicalResult |
| 340 | AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( |
| 341 | Value value) { |
| 342 | using BlockSet = llvm::SmallPtrSet<Block *, 4>; |
| 343 | |
| 344 | OpBuilder builder(value.getContext()); |
| 345 | |
| 346 | // If a block has successors with different `liveIn` property of the `value`, |
| 347 | // record block successors that do not thave the `value` in the `liveIn` set. |
| 348 | llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; |
| 349 | |
| 350 | // Use liveness analysis to find the placement of `drop_ref`operation. |
| 351 | auto &liveness = getAnalysis<Liveness>(); |
| 352 | |
| 353 | // Because we only add `drop_ref` operations to the region that defines the |
| 354 | // `value` we can only process CFG for the same region. |
| 355 | Region *definingRegion = value.getParentRegion(); |
| 356 | |
| 357 | // Collect blocks with successors with mismatching `liveIn` sets. |
| 358 | for (Block &block : definingRegion->getBlocks()) { |
| 359 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); |
| 360 | |
| 361 | // Skip the block if value is not in the `liveOut` set. |
| 362 | if (!blockLiveness || !blockLiveness->isLiveOut(value)) |
| 363 | continue; |
| 364 | |
| 365 | BlockSet liveInSuccessors; // `value` is in `liveIn` set |
| 366 | BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set |
| 367 | |
| 368 | // Collect successors that do not have `value` in the `liveIn` set. |
| 369 | for (Block *successor : block.getSuccessors()) { |
| 370 | const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); |
| 371 | if (succLiveness && succLiveness->isLiveIn(value)) |
| 372 | liveInSuccessors.insert(Ptr: successor); |
| 373 | else |
| 374 | noLiveInSuccessors.insert(Ptr: successor); |
| 375 | } |
| 376 | |
| 377 | // Block has successors with different `liveIn` property of the `value`. |
| 378 | if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) |
| 379 | divergentLivenessBlocks.try_emplace(Key: &block, Args&: noLiveInSuccessors); |
| 380 | } |
| 381 | |
| 382 | // Try to insert `dropRef` operations to handle blocks with divergent liveness |
| 383 | // in successors blocks. |
| 384 | for (auto kv : divergentLivenessBlocks) { |
| 385 | Block *block = kv.getFirst(); |
| 386 | BlockSet &successors = kv.getSecond(); |
| 387 | |
| 388 | // Coroutine suspension is a special case terminator for wich we do not |
| 389 | // need to create additional reference counting (see details above). |
| 390 | Operation *terminator = block->getTerminator(); |
| 391 | if (isa<CoroSuspendOp>(terminator)) |
| 392 | continue; |
| 393 | |
| 394 | // We only support successor blocks with empty block argument list. |
| 395 | auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; |
| 396 | if (llvm::any_of(Range&: successors, P: hasArgs)) |
| 397 | return terminator->emitOpError() |
| 398 | << "successor have different `liveIn` property of the reference " |
| 399 | "counted value" ; |
| 400 | |
| 401 | // Make sure that `dropRef` operation is called when branched into the |
| 402 | // successor block without `value` in the `liveIn` set. |
| 403 | for (Block *successor : successors) { |
| 404 | // If successor has a unique predecessor, it is safe to create `dropRef` |
| 405 | // operations directly in the successor block. |
| 406 | // |
| 407 | // Otherwise we need to create a special block for reference counting |
| 408 | // operations, and branch from it to the original successor block. |
| 409 | Block *refCountingBlock = nullptr; |
| 410 | |
| 411 | if (successor->getUniquePredecessor() == block) { |
| 412 | refCountingBlock = successor; |
| 413 | } else { |
| 414 | refCountingBlock = &successor->getParent()->emplaceBlock(); |
| 415 | refCountingBlock->moveBefore(block: successor); |
| 416 | OpBuilder builder = OpBuilder::atBlockEnd(block: refCountingBlock); |
| 417 | builder.create<cf::BranchOp>(value.getLoc(), successor); |
| 418 | } |
| 419 | |
| 420 | OpBuilder builder = OpBuilder::atBlockBegin(block: refCountingBlock); |
| 421 | builder.create<RuntimeDropRefOp>(value.getLoc(), value, |
| 422 | builder.getI64IntegerAttr(1)); |
| 423 | |
| 424 | // No need to update the terminator operation. |
| 425 | if (successor == refCountingBlock) |
| 426 | continue; |
| 427 | |
| 428 | // Update terminator `successor` block to `refCountingBlock`. |
| 429 | for (const auto &pair : llvm::enumerate(First: terminator->getSuccessors())) |
| 430 | if (pair.value() == successor) |
| 431 | terminator->setSuccessor(block: refCountingBlock, index: pair.index()); |
| 432 | } |
| 433 | } |
| 434 | |
| 435 | return success(); |
| 436 | } |
| 437 | |
| 438 | LogicalResult |
| 439 | AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { |
| 440 | // Short-circuit reference counting for values without uses. |
| 441 | if (succeeded(Result: dropRefIfNoUses(value))) |
| 442 | return success(); |
| 443 | |
| 444 | // Add `drop_ref` operations based on the liveness analysis. |
| 445 | if (failed(Result: addDropRefAfterLastUse(value))) |
| 446 | return failure(); |
| 447 | |
| 448 | // Add `add_ref` operations before function calls. |
| 449 | if (failed(Result: addAddRefBeforeFunctionCall(value))) |
| 450 | return failure(); |
| 451 | |
| 452 | // Add `drop_ref` operations to successors with divergent `value` liveness. |
| 453 | if (failed(Result: addDropRefInDivergentLivenessSuccessor(value))) |
| 454 | return failure(); |
| 455 | |
| 456 | return success(); |
| 457 | } |
| 458 | |
| 459 | void AsyncRuntimeRefCountingPass::runOnOperation() { |
| 460 | auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; |
| 461 | if (failed(walkReferenceCountedValues(getOperation(), functor))) |
| 462 | signalPassFailure(); |
| 463 | } |
| 464 | |
| 465 | //===----------------------------------------------------------------------===// |
| 466 | // Reference counting based on the user defined policy. |
| 467 | //===----------------------------------------------------------------------===// |
| 468 | |
| 469 | namespace { |
| 470 | |
| 471 | class AsyncRuntimePolicyBasedRefCountingPass |
| 472 | : public impl::AsyncRuntimePolicyBasedRefCountingPassBase< |
| 473 | AsyncRuntimePolicyBasedRefCountingPass> { |
| 474 | public: |
| 475 | AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } |
| 476 | |
| 477 | void runOnOperation() override; |
| 478 | |
| 479 | private: |
| 480 | // Adds a reference counting operations for all uses of the `value` according |
| 481 | // to the reference counting policy. |
| 482 | LogicalResult addRefCounting(Value value); |
| 483 | |
| 484 | void initializeDefaultPolicy(); |
| 485 | |
| 486 | llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; |
| 487 | }; |
| 488 | |
| 489 | } // namespace |
| 490 | |
| 491 | LogicalResult |
| 492 | AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { |
| 493 | // Short-circuit reference counting for values without uses. |
| 494 | if (succeeded(Result: dropRefIfNoUses(value))) |
| 495 | return success(); |
| 496 | |
| 497 | OpBuilder b(value.getContext()); |
| 498 | |
| 499 | // Consult the user defined policy for every value use. |
| 500 | for (OpOperand &operand : value.getUses()) { |
| 501 | Location loc = operand.getOwner()->getLoc(); |
| 502 | |
| 503 | for (auto &func : policy) { |
| 504 | FailureOr<int> refCount = func(operand); |
| 505 | if (failed(Result: refCount)) |
| 506 | return failure(); |
| 507 | |
| 508 | int cnt = *refCount; |
| 509 | |
| 510 | // Create `add_ref` operation before the operand owner. |
| 511 | if (cnt > 0) { |
| 512 | b.setInsertionPoint(operand.getOwner()); |
| 513 | b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); |
| 514 | } |
| 515 | |
| 516 | // Create `drop_ref` operation after the operand owner. |
| 517 | if (cnt < 0) { |
| 518 | b.setInsertionPointAfter(operand.getOwner()); |
| 519 | b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); |
| 520 | } |
| 521 | } |
| 522 | } |
| 523 | |
| 524 | return success(); |
| 525 | } |
| 526 | |
| 527 | void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { |
| 528 | policy.push_back(Elt: [](OpOperand &operand) -> FailureOr<int> { |
| 529 | Operation *op = operand.getOwner(); |
| 530 | Type type = operand.get().getType(); |
| 531 | |
| 532 | bool isToken = isa<TokenType>(type); |
| 533 | bool isGroup = isa<GroupType>(type); |
| 534 | bool isValue = isa<ValueType>(type); |
| 535 | |
| 536 | // Drop reference after async token or group error check (coro await). |
| 537 | if (isa<RuntimeIsErrorOp>(op)) |
| 538 | return (isToken || isGroup) ? -1 : 0; |
| 539 | |
| 540 | // Drop reference after async value load. |
| 541 | if (isa<RuntimeLoadOp>(op)) |
| 542 | return isValue ? -1 : 0; |
| 543 | |
| 544 | // Drop reference after async token added to the group. |
| 545 | if (isa<RuntimeAddToGroupOp>(op)) |
| 546 | return isToken ? -1 : 0; |
| 547 | |
| 548 | return 0; |
| 549 | }); |
| 550 | } |
| 551 | |
| 552 | void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { |
| 553 | auto functor = [&](Value value) { return addRefCounting(value); }; |
| 554 | if (failed(walkReferenceCountedValues(getOperation(), functor))) |
| 555 | signalPassFailure(); |
| 556 | } |
| 557 | |