1 | //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements automatic reference counting for Async runtime |
10 | // operations and types. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Async/Passes.h" |
15 | |
16 | #include "mlir/Analysis/Liveness.h" |
17 | #include "mlir/Dialect/Async/IR/Async.h" |
18 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | #include "llvm/ADT/SmallSet.h" |
24 | |
25 | namespace mlir { |
26 | #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING |
27 | #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING |
28 | #include "mlir/Dialect/Async/Passes.h.inc" |
29 | } // namespace mlir |
30 | |
31 | #define DEBUG_TYPE "async-runtime-ref-counting" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::async; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // Utility functions shared by reference counting passes. |
38 | //===----------------------------------------------------------------------===// |
39 | |
40 | // Drop the reference count immediately if the value has no uses. |
41 | static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { |
42 | if (!value.getUses().empty()) |
43 | return failure(); |
44 | |
45 | OpBuilder b(value.getContext()); |
46 | |
47 | // Set insertion point after the operation producing a value, or at the |
48 | // beginning of the block if the value defined by the block argument. |
49 | if (Operation *op = value.getDefiningOp()) |
50 | b.setInsertionPointAfter(op); |
51 | else |
52 | b.setInsertionPointToStart(value.getParentBlock()); |
53 | |
54 | b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); |
55 | return success(); |
56 | } |
57 | |
58 | // Calls `addRefCounting` for every reference counted value defined by the |
59 | // operation `op` (block arguments and values defined in nested regions). |
60 | static LogicalResult walkReferenceCountedValues( |
61 | Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { |
62 | // Check that we do not have high level async operations in the IR because |
63 | // otherwise reference counting will produce incorrect results after high |
64 | // level async operations will be lowered to `async.runtime` |
65 | WalkResult checkNoAsyncWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
66 | if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) |
67 | return WalkResult::advance(); |
68 | |
69 | return op->emitError() |
70 | << "async operations must be lowered to async runtime operations" ; |
71 | }); |
72 | |
73 | if (checkNoAsyncWalk.wasInterrupted()) |
74 | return failure(); |
75 | |
76 | // Add reference counting to block arguments. |
77 | WalkResult blockWalk = op->walk(callback: [&](Block *block) -> WalkResult { |
78 | for (BlockArgument arg : block->getArguments()) |
79 | if (isRefCounted(type: arg.getType())) |
80 | if (failed(result: addRefCounting(arg))) |
81 | return WalkResult::interrupt(); |
82 | |
83 | return WalkResult::advance(); |
84 | }); |
85 | |
86 | if (blockWalk.wasInterrupted()) |
87 | return failure(); |
88 | |
89 | // Add reference counting to operation results. |
90 | WalkResult opWalk = op->walk(callback: [&](Operation *op) -> WalkResult { |
91 | for (unsigned i = 0; i < op->getNumResults(); ++i) |
92 | if (isRefCounted(type: op->getResultTypes()[i])) |
93 | if (failed(result: addRefCounting(op->getResult(idx: i)))) |
94 | return WalkResult::interrupt(); |
95 | |
96 | return WalkResult::advance(); |
97 | }); |
98 | |
99 | if (opWalk.wasInterrupted()) |
100 | return failure(); |
101 | |
102 | return success(); |
103 | } |
104 | |
105 | //===----------------------------------------------------------------------===// |
106 | // Automatic reference counting based on the liveness analysis. |
107 | //===----------------------------------------------------------------------===// |
108 | |
109 | namespace { |
110 | |
111 | class AsyncRuntimeRefCountingPass |
112 | : public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { |
113 | public: |
114 | AsyncRuntimeRefCountingPass() = default; |
115 | void runOnOperation() override; |
116 | |
117 | private: |
118 | /// Adds an automatic reference counting to the `value`. |
119 | /// |
120 | /// All values (token, group or value) are semantically created with a |
121 | /// reference count of +1 and it is the responsibility of the async value user |
122 | /// to place the `add_ref` and `drop_ref` operations to ensure that the value |
123 | /// is destroyed after the last use. |
124 | /// |
125 | /// The function returns failure if it can't deduce the locations where |
126 | /// to place the reference counting operations. |
127 | /// |
128 | /// Async values "semantically created" when: |
129 | /// 1. Operation returns async result (e.g. `async.runtime.create`) |
130 | /// 2. Async value passed in as a block argument (or function argument, |
131 | /// because function arguments are just entry block arguments) |
132 | /// |
133 | /// Passing async value as a function argument (or block argument) does not |
134 | /// really mean that a new async value is created, it only means that the |
135 | /// caller of a function transfered ownership of `+1` reference to the callee. |
136 | /// It is convenient to think that from the callee perspective async value was |
137 | /// "created" with `+1` reference by the block argument. |
138 | /// |
139 | /// Automatic reference counting algorithm outline: |
140 | /// |
141 | /// #1 Insert `drop_ref` operations after last use of the `value`. |
142 | /// #2 Insert `add_ref` operations before functions calls with reference |
143 | /// counted `value` operand (newly created `+1` reference will be |
144 | /// transferred to the callee). |
145 | /// #3 Verify that divergent control flow does not lead to leaked reference |
146 | /// counted objects. |
147 | /// |
148 | /// Async runtime reference counting optimization pass will optimize away |
149 | /// some of the redundant `add_ref` and `drop_ref` operations inserted by this |
150 | /// strategy (see `async-runtime-ref-counting-opt`). |
151 | LogicalResult addAutomaticRefCounting(Value value); |
152 | |
153 | /// (#1) Adds the `drop_ref` operation after the last use of the `value` |
154 | /// relying on the liveness analysis. |
155 | /// |
156 | /// If the `value` is in the block `liveIn` set and it is not in the block |
157 | /// `liveOut` set, it means that it "dies" in the block. We find the last |
158 | /// use of the value in such block and: |
159 | /// |
160 | /// 1. If the last user is a `ReturnLike` operation we do nothing, because |
161 | /// it forwards the ownership to the caller. |
162 | /// 2. Otherwise we add a `drop_ref` operation immediately after the last |
163 | /// use. |
164 | LogicalResult addDropRefAfterLastUse(Value value); |
165 | |
166 | /// (#2) Adds the `add_ref` operation before the function call taking `value` |
167 | /// operand to ensure that the value passed to the function entry block |
168 | /// has a `+1` reference count. |
169 | LogicalResult addAddRefBeforeFunctionCall(Value value); |
170 | |
171 | /// (#3) Adds the `drop_ref` operation to account for successor blocks with |
172 | /// divergent `liveIn` property: `value` is not in the `liveIn` set of all |
173 | /// successor blocks. |
174 | /// |
175 | /// Example: |
176 | /// |
177 | /// ^entry: |
178 | /// %token = async.runtime.create : !async.token |
179 | /// cf.cond_br %cond, ^bb1, ^bb2 |
180 | /// ^bb1: |
181 | /// async.runtime.await %token |
182 | /// async.runtime.drop_ref %token |
183 | /// cf.br ^bb2 |
184 | /// ^bb2: |
185 | /// return |
186 | /// |
187 | /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have |
188 | /// to branch into a special "reference counting block" from the ^entry that |
189 | /// will have a `drop_ref` operation, and then branch into the ^bb2. |
190 | /// |
191 | /// After transformation: |
192 | /// |
193 | /// ^entry: |
194 | /// %token = async.runtime.create : !async.token |
195 | /// cf.cond_br %cond, ^bb1, ^reference_counting |
196 | /// ^bb1: |
197 | /// async.runtime.await %token |
198 | /// async.runtime.drop_ref %token |
199 | /// cf.br ^bb2 |
200 | /// ^reference_counting: |
201 | /// async.runtime.drop_ref %token |
202 | /// cf.br ^bb2 |
203 | /// ^bb2: |
204 | /// return |
205 | /// |
206 | /// An exception to this rule are blocks with `async.coro.suspend` terminator, |
207 | /// because in Async to LLVM lowering it is guaranteed that the control flow |
208 | /// will jump into the resume block, and then follow into the cleanup and |
209 | /// suspend blocks. |
210 | /// |
211 | /// Example: |
212 | /// |
213 | /// ^entry(%value: !async.value<f32>): |
214 | /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> |
215 | /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup |
216 | /// ^resume: |
217 | /// %0 = async.runtime.load %value |
218 | /// cf.br ^cleanup |
219 | /// ^cleanup: |
220 | /// ... |
221 | /// ^suspend: |
222 | /// ... |
223 | /// |
224 | /// Although cleanup and suspend blocks do not have the `value` in the |
225 | /// `liveIn` set, it is guaranteed that execution will eventually continue in |
226 | /// the resume block (we never explicitly destroy coroutines). |
227 | LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); |
228 | }; |
229 | |
230 | } // namespace |
231 | |
232 | LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { |
233 | OpBuilder builder(value.getContext()); |
234 | Location loc = value.getLoc(); |
235 | |
236 | // Use liveness analysis to find the placement of `drop_ref`operation. |
237 | auto &liveness = getAnalysis<Liveness>(); |
238 | |
239 | // We analyse only the blocks of the region that defines the `value`, and do |
240 | // not check nested blocks attached to operations. |
241 | // |
242 | // By analyzing only the `definingRegion` CFG we potentially loose an |
243 | // opportunity to drop the reference count earlier and can extend the lifetime |
244 | // of reference counted value longer then it is really required. |
245 | // |
246 | // We also assume that all nested regions finish their execution before the |
247 | // completion of the owner operation. The only exception to this rule is |
248 | // `async.execute` operation, and we verify that they are lowered to the |
249 | // `async.runtime` operations before adding automatic reference counting. |
250 | Region *definingRegion = value.getParentRegion(); |
251 | |
252 | // Last users of the `value` inside all blocks where the value dies. |
253 | llvm::SmallSet<Operation *, 4> lastUsers; |
254 | |
255 | // Find blocks in the `definingRegion` that have users of the `value` (if |
256 | // there are multiple users in the block, which one will be selected is |
257 | // undefined). User operation might be not the actual user of the value, but |
258 | // the operation in the block that has a "real user" in one of the attached |
259 | // regions. |
260 | llvm::DenseMap<Block *, Operation *> usersInTheBlocks; |
261 | |
262 | for (Operation *user : value.getUsers()) { |
263 | Block *userBlock = user->getBlock(); |
264 | Block *ancestor = definingRegion->findAncestorBlockInRegion(block&: *userBlock); |
265 | usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(op&: *user); |
266 | assert(ancestor && "ancestor block must be not null" ); |
267 | assert(usersInTheBlocks[ancestor] && "ancestor op must be not null" ); |
268 | } |
269 | |
270 | // Find blocks where the `value` dies: the value is in `liveIn` set and not |
271 | // in the `liveOut` set. We place `drop_ref` immediately after the last use |
272 | // of the `value` in such regions (after handling few special cases). |
273 | // |
274 | // We do not traverse all the blocks in the `definingRegion`, because the |
275 | // `value` can be in the live in set only if it has users in the block, or it |
276 | // is defined in the block. |
277 | // |
278 | // Values with zero users (only definition) handled explicitly above. |
279 | for (auto &blockAndUser : usersInTheBlocks) { |
280 | Block *block = blockAndUser.getFirst(); |
281 | Operation *userInTheBlock = blockAndUser.getSecond(); |
282 | |
283 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); |
284 | |
285 | // Value must be in the live input set or defined in the block. |
286 | assert(blockLiveness->isLiveIn(value) || |
287 | blockLiveness->getBlock() == value.getParentBlock()); |
288 | |
289 | // If value is in the live out set, it means it doesn't "die" in the block. |
290 | if (blockLiveness->isLiveOut(value)) |
291 | continue; |
292 | |
293 | // At this point we proved that `value` dies in the `block`. Find the last |
294 | // use of the `value` inside the `block`, this is where it "dies". |
295 | Operation *lastUser = blockLiveness->getEndOperation(value, startOperation: userInTheBlock); |
296 | assert(lastUsers.count(lastUser) == 0 && "last users must be unique" ); |
297 | lastUsers.insert(Ptr: lastUser); |
298 | } |
299 | |
300 | // Process all the last users of the `value` inside each block where the value |
301 | // dies. |
302 | for (Operation *lastUser : lastUsers) { |
303 | // Return like operations forward reference count. |
304 | if (lastUser->hasTrait<OpTrait::ReturnLike>()) |
305 | continue; |
306 | |
307 | // We can't currently handle other types of terminators. |
308 | if (lastUser->hasTrait<OpTrait::IsTerminator>()) |
309 | return lastUser->emitError() << "async reference counting can't handle " |
310 | "terminators that are not ReturnLike" ; |
311 | |
312 | // Add a drop_ref immediately after the last user. |
313 | builder.setInsertionPointAfter(lastUser); |
314 | builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
315 | } |
316 | |
317 | return success(); |
318 | } |
319 | |
320 | LogicalResult |
321 | AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { |
322 | OpBuilder builder(value.getContext()); |
323 | Location loc = value.getLoc(); |
324 | |
325 | for (Operation *user : value.getUsers()) { |
326 | if (!isa<func::CallOp>(user)) |
327 | continue; |
328 | |
329 | // Add a reference before the function call to pass the value at `+1` |
330 | // reference to the function entry block. |
331 | builder.setInsertionPoint(user); |
332 | builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
333 | } |
334 | |
335 | return success(); |
336 | } |
337 | |
338 | LogicalResult |
339 | AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( |
340 | Value value) { |
341 | using BlockSet = llvm::SmallPtrSet<Block *, 4>; |
342 | |
343 | OpBuilder builder(value.getContext()); |
344 | |
345 | // If a block has successors with different `liveIn` property of the `value`, |
346 | // record block successors that do not thave the `value` in the `liveIn` set. |
347 | llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; |
348 | |
349 | // Use liveness analysis to find the placement of `drop_ref`operation. |
350 | auto &liveness = getAnalysis<Liveness>(); |
351 | |
352 | // Because we only add `drop_ref` operations to the region that defines the |
353 | // `value` we can only process CFG for the same region. |
354 | Region *definingRegion = value.getParentRegion(); |
355 | |
356 | // Collect blocks with successors with mismatching `liveIn` sets. |
357 | for (Block &block : definingRegion->getBlocks()) { |
358 | const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); |
359 | |
360 | // Skip the block if value is not in the `liveOut` set. |
361 | if (!blockLiveness || !blockLiveness->isLiveOut(value)) |
362 | continue; |
363 | |
364 | BlockSet liveInSuccessors; // `value` is in `liveIn` set |
365 | BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set |
366 | |
367 | // Collect successors that do not have `value` in the `liveIn` set. |
368 | for (Block *successor : block.getSuccessors()) { |
369 | const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); |
370 | if (succLiveness && succLiveness->isLiveIn(value)) |
371 | liveInSuccessors.insert(Ptr: successor); |
372 | else |
373 | noLiveInSuccessors.insert(Ptr: successor); |
374 | } |
375 | |
376 | // Block has successors with different `liveIn` property of the `value`. |
377 | if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) |
378 | divergentLivenessBlocks.try_emplace(Key: &block, Args&: noLiveInSuccessors); |
379 | } |
380 | |
381 | // Try to insert `dropRef` operations to handle blocks with divergent liveness |
382 | // in successors blocks. |
383 | for (auto kv : divergentLivenessBlocks) { |
384 | Block *block = kv.getFirst(); |
385 | BlockSet &successors = kv.getSecond(); |
386 | |
387 | // Coroutine suspension is a special case terminator for wich we do not |
388 | // need to create additional reference counting (see details above). |
389 | Operation *terminator = block->getTerminator(); |
390 | if (isa<CoroSuspendOp>(terminator)) |
391 | continue; |
392 | |
393 | // We only support successor blocks with empty block argument list. |
394 | auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; |
395 | if (llvm::any_of(Range&: successors, P: hasArgs)) |
396 | return terminator->emitOpError() |
397 | << "successor have different `liveIn` property of the reference " |
398 | "counted value" ; |
399 | |
400 | // Make sure that `dropRef` operation is called when branched into the |
401 | // successor block without `value` in the `liveIn` set. |
402 | for (Block *successor : successors) { |
403 | // If successor has a unique predecessor, it is safe to create `dropRef` |
404 | // operations directly in the successor block. |
405 | // |
406 | // Otherwise we need to create a special block for reference counting |
407 | // operations, and branch from it to the original successor block. |
408 | Block *refCountingBlock = nullptr; |
409 | |
410 | if (successor->getUniquePredecessor() == block) { |
411 | refCountingBlock = successor; |
412 | } else { |
413 | refCountingBlock = &successor->getParent()->emplaceBlock(); |
414 | refCountingBlock->moveBefore(block: successor); |
415 | OpBuilder builder = OpBuilder::atBlockEnd(block: refCountingBlock); |
416 | builder.create<cf::BranchOp>(value.getLoc(), successor); |
417 | } |
418 | |
419 | OpBuilder builder = OpBuilder::atBlockBegin(block: refCountingBlock); |
420 | builder.create<RuntimeDropRefOp>(value.getLoc(), value, |
421 | builder.getI64IntegerAttr(1)); |
422 | |
423 | // No need to update the terminator operation. |
424 | if (successor == refCountingBlock) |
425 | continue; |
426 | |
427 | // Update terminator `successor` block to `refCountingBlock`. |
428 | for (const auto &pair : llvm::enumerate(First: terminator->getSuccessors())) |
429 | if (pair.value() == successor) |
430 | terminator->setSuccessor(block: refCountingBlock, index: pair.index()); |
431 | } |
432 | } |
433 | |
434 | return success(); |
435 | } |
436 | |
437 | LogicalResult |
438 | AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { |
439 | // Short-circuit reference counting for values without uses. |
440 | if (succeeded(result: dropRefIfNoUses(value))) |
441 | return success(); |
442 | |
443 | // Add `drop_ref` operations based on the liveness analysis. |
444 | if (failed(result: addDropRefAfterLastUse(value))) |
445 | return failure(); |
446 | |
447 | // Add `add_ref` operations before function calls. |
448 | if (failed(result: addAddRefBeforeFunctionCall(value))) |
449 | return failure(); |
450 | |
451 | // Add `drop_ref` operations to successors with divergent `value` liveness. |
452 | if (failed(result: addDropRefInDivergentLivenessSuccessor(value))) |
453 | return failure(); |
454 | |
455 | return success(); |
456 | } |
457 | |
458 | void AsyncRuntimeRefCountingPass::runOnOperation() { |
459 | auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; |
460 | if (failed(walkReferenceCountedValues(getOperation(), functor))) |
461 | signalPassFailure(); |
462 | } |
463 | |
464 | //===----------------------------------------------------------------------===// |
465 | // Reference counting based on the user defined policy. |
466 | //===----------------------------------------------------------------------===// |
467 | |
468 | namespace { |
469 | |
470 | class AsyncRuntimePolicyBasedRefCountingPass |
471 | : public impl::AsyncRuntimePolicyBasedRefCountingBase< |
472 | AsyncRuntimePolicyBasedRefCountingPass> { |
473 | public: |
474 | AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } |
475 | |
476 | void runOnOperation() override; |
477 | |
478 | private: |
479 | // Adds a reference counting operations for all uses of the `value` according |
480 | // to the reference counting policy. |
481 | LogicalResult addRefCounting(Value value); |
482 | |
483 | void initializeDefaultPolicy(); |
484 | |
485 | llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; |
486 | }; |
487 | |
488 | } // namespace |
489 | |
490 | LogicalResult |
491 | AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { |
492 | // Short-circuit reference counting for values without uses. |
493 | if (succeeded(result: dropRefIfNoUses(value))) |
494 | return success(); |
495 | |
496 | OpBuilder b(value.getContext()); |
497 | |
498 | // Consult the user defined policy for every value use. |
499 | for (OpOperand &operand : value.getUses()) { |
500 | Location loc = operand.getOwner()->getLoc(); |
501 | |
502 | for (auto &func : policy) { |
503 | FailureOr<int> refCount = func(operand); |
504 | if (failed(result: refCount)) |
505 | return failure(); |
506 | |
507 | int cnt = *refCount; |
508 | |
509 | // Create `add_ref` operation before the operand owner. |
510 | if (cnt > 0) { |
511 | b.setInsertionPoint(operand.getOwner()); |
512 | b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); |
513 | } |
514 | |
515 | // Create `drop_ref` operation after the operand owner. |
516 | if (cnt < 0) { |
517 | b.setInsertionPointAfter(operand.getOwner()); |
518 | b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); |
519 | } |
520 | } |
521 | } |
522 | |
523 | return success(); |
524 | } |
525 | |
526 | void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { |
527 | policy.push_back(Elt: [](OpOperand &operand) -> FailureOr<int> { |
528 | Operation *op = operand.getOwner(); |
529 | Type type = operand.get().getType(); |
530 | |
531 | bool isToken = isa<TokenType>(type); |
532 | bool isGroup = isa<GroupType>(type); |
533 | bool isValue = isa<ValueType>(type); |
534 | |
535 | // Drop reference after async token or group error check (coro await). |
536 | if (dyn_cast<RuntimeIsErrorOp>(op)) |
537 | return (isToken || isGroup) ? -1 : 0; |
538 | |
539 | // Drop reference after async value load. |
540 | if (dyn_cast<RuntimeLoadOp>(op)) |
541 | return isValue ? -1 : 0; |
542 | |
543 | // Drop reference after async token added to the group. |
544 | if (dyn_cast<RuntimeAddToGroupOp>(op)) |
545 | return isToken ? -1 : 0; |
546 | |
547 | return 0; |
548 | }); |
549 | } |
550 | |
551 | void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { |
552 | auto functor = [&](Value value) { return addRefCounting(value); }; |
553 | if (failed(walkReferenceCountedValues(getOperation(), functor))) |
554 | signalPassFailure(); |
555 | } |
556 | |
557 | //----------------------------------------------------------------------------// |
558 | |
559 | std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { |
560 | return std::make_unique<AsyncRuntimeRefCountingPass>(); |
561 | } |
562 | |
563 | std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { |
564 | return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); |
565 | } |
566 | |