1//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements automatic reference counting for Async runtime
10// operations and types.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Async/Passes.h"
15
16#include "mlir/Analysis/Liveness.h"
17#include "mlir/Dialect/Async/IR/Async.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/IR/ImplicitLocOpBuilder.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#include "llvm/ADT/SmallSet.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING
27#define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING
28#include "mlir/Dialect/Async/Passes.h.inc"
29} // namespace mlir
30
31#define DEBUG_TYPE "async-runtime-ref-counting"
32
33using namespace mlir;
34using namespace mlir::async;
35
36//===----------------------------------------------------------------------===//
37// Utility functions shared by reference counting passes.
38//===----------------------------------------------------------------------===//
39
40// Drop the reference count immediately if the value has no uses.
41static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
42 if (!value.getUses().empty())
43 return failure();
44
45 OpBuilder b(value.getContext());
46
47 // Set insertion point after the operation producing a value, or at the
48 // beginning of the block if the value defined by the block argument.
49 if (Operation *op = value.getDefiningOp())
50 b.setInsertionPointAfter(op);
51 else
52 b.setInsertionPointToStart(value.getParentBlock());
53
54 b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
55 return success();
56}
57
58// Calls `addRefCounting` for every reference counted value defined by the
59// operation `op` (block arguments and values defined in nested regions).
60static LogicalResult walkReferenceCountedValues(
61 Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
62 // Check that we do not have high level async operations in the IR because
63 // otherwise reference counting will produce incorrect results after high
64 // level async operations will be lowered to `async.runtime`
65 WalkResult checkNoAsyncWalk = op->walk(callback: [&](Operation *op) -> WalkResult {
66 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
67 return WalkResult::advance();
68
69 return op->emitError()
70 << "async operations must be lowered to async runtime operations";
71 });
72
73 if (checkNoAsyncWalk.wasInterrupted())
74 return failure();
75
76 // Add reference counting to block arguments.
77 WalkResult blockWalk = op->walk(callback: [&](Block *block) -> WalkResult {
78 for (BlockArgument arg : block->getArguments())
79 if (isRefCounted(type: arg.getType()))
80 if (failed(result: addRefCounting(arg)))
81 return WalkResult::interrupt();
82
83 return WalkResult::advance();
84 });
85
86 if (blockWalk.wasInterrupted())
87 return failure();
88
89 // Add reference counting to operation results.
90 WalkResult opWalk = op->walk(callback: [&](Operation *op) -> WalkResult {
91 for (unsigned i = 0; i < op->getNumResults(); ++i)
92 if (isRefCounted(type: op->getResultTypes()[i]))
93 if (failed(result: addRefCounting(op->getResult(idx: i))))
94 return WalkResult::interrupt();
95
96 return WalkResult::advance();
97 });
98
99 if (opWalk.wasInterrupted())
100 return failure();
101
102 return success();
103}
104
105//===----------------------------------------------------------------------===//
106// Automatic reference counting based on the liveness analysis.
107//===----------------------------------------------------------------------===//
108
109namespace {
110
111class AsyncRuntimeRefCountingPass
112 : public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
113public:
114 AsyncRuntimeRefCountingPass() = default;
115 void runOnOperation() override;
116
117private:
118 /// Adds an automatic reference counting to the `value`.
119 ///
120 /// All values (token, group or value) are semantically created with a
121 /// reference count of +1 and it is the responsibility of the async value user
122 /// to place the `add_ref` and `drop_ref` operations to ensure that the value
123 /// is destroyed after the last use.
124 ///
125 /// The function returns failure if it can't deduce the locations where
126 /// to place the reference counting operations.
127 ///
128 /// Async values "semantically created" when:
129 /// 1. Operation returns async result (e.g. `async.runtime.create`)
130 /// 2. Async value passed in as a block argument (or function argument,
131 /// because function arguments are just entry block arguments)
132 ///
133 /// Passing async value as a function argument (or block argument) does not
134 /// really mean that a new async value is created, it only means that the
135 /// caller of a function transfered ownership of `+1` reference to the callee.
136 /// It is convenient to think that from the callee perspective async value was
137 /// "created" with `+1` reference by the block argument.
138 ///
139 /// Automatic reference counting algorithm outline:
140 ///
141 /// #1 Insert `drop_ref` operations after last use of the `value`.
142 /// #2 Insert `add_ref` operations before functions calls with reference
143 /// counted `value` operand (newly created `+1` reference will be
144 /// transferred to the callee).
145 /// #3 Verify that divergent control flow does not lead to leaked reference
146 /// counted objects.
147 ///
148 /// Async runtime reference counting optimization pass will optimize away
149 /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
150 /// strategy (see `async-runtime-ref-counting-opt`).
151 LogicalResult addAutomaticRefCounting(Value value);
152
153 /// (#1) Adds the `drop_ref` operation after the last use of the `value`
154 /// relying on the liveness analysis.
155 ///
156 /// If the `value` is in the block `liveIn` set and it is not in the block
157 /// `liveOut` set, it means that it "dies" in the block. We find the last
158 /// use of the value in such block and:
159 ///
160 /// 1. If the last user is a `ReturnLike` operation we do nothing, because
161 /// it forwards the ownership to the caller.
162 /// 2. Otherwise we add a `drop_ref` operation immediately after the last
163 /// use.
164 LogicalResult addDropRefAfterLastUse(Value value);
165
166 /// (#2) Adds the `add_ref` operation before the function call taking `value`
167 /// operand to ensure that the value passed to the function entry block
168 /// has a `+1` reference count.
169 LogicalResult addAddRefBeforeFunctionCall(Value value);
170
171 /// (#3) Adds the `drop_ref` operation to account for successor blocks with
172 /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
173 /// successor blocks.
174 ///
175 /// Example:
176 ///
177 /// ^entry:
178 /// %token = async.runtime.create : !async.token
179 /// cf.cond_br %cond, ^bb1, ^bb2
180 /// ^bb1:
181 /// async.runtime.await %token
182 /// async.runtime.drop_ref %token
183 /// cf.br ^bb2
184 /// ^bb2:
185 /// return
186 ///
187 /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
188 /// to branch into a special "reference counting block" from the ^entry that
189 /// will have a `drop_ref` operation, and then branch into the ^bb2.
190 ///
191 /// After transformation:
192 ///
193 /// ^entry:
194 /// %token = async.runtime.create : !async.token
195 /// cf.cond_br %cond, ^bb1, ^reference_counting
196 /// ^bb1:
197 /// async.runtime.await %token
198 /// async.runtime.drop_ref %token
199 /// cf.br ^bb2
200 /// ^reference_counting:
201 /// async.runtime.drop_ref %token
202 /// cf.br ^bb2
203 /// ^bb2:
204 /// return
205 ///
206 /// An exception to this rule are blocks with `async.coro.suspend` terminator,
207 /// because in Async to LLVM lowering it is guaranteed that the control flow
208 /// will jump into the resume block, and then follow into the cleanup and
209 /// suspend blocks.
210 ///
211 /// Example:
212 ///
213 /// ^entry(%value: !async.value<f32>):
214 /// async.runtime.await_and_resume %value, %hdl : !async.value<f32>
215 /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
216 /// ^resume:
217 /// %0 = async.runtime.load %value
218 /// cf.br ^cleanup
219 /// ^cleanup:
220 /// ...
221 /// ^suspend:
222 /// ...
223 ///
224 /// Although cleanup and suspend blocks do not have the `value` in the
225 /// `liveIn` set, it is guaranteed that execution will eventually continue in
226 /// the resume block (we never explicitly destroy coroutines).
227 LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
228};
229
230} // namespace
231
232LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
233 OpBuilder builder(value.getContext());
234 Location loc = value.getLoc();
235
236 // Use liveness analysis to find the placement of `drop_ref`operation.
237 auto &liveness = getAnalysis<Liveness>();
238
239 // We analyse only the blocks of the region that defines the `value`, and do
240 // not check nested blocks attached to operations.
241 //
242 // By analyzing only the `definingRegion` CFG we potentially loose an
243 // opportunity to drop the reference count earlier and can extend the lifetime
244 // of reference counted value longer then it is really required.
245 //
246 // We also assume that all nested regions finish their execution before the
247 // completion of the owner operation. The only exception to this rule is
248 // `async.execute` operation, and we verify that they are lowered to the
249 // `async.runtime` operations before adding automatic reference counting.
250 Region *definingRegion = value.getParentRegion();
251
252 // Last users of the `value` inside all blocks where the value dies.
253 llvm::SmallSet<Operation *, 4> lastUsers;
254
255 // Find blocks in the `definingRegion` that have users of the `value` (if
256 // there are multiple users in the block, which one will be selected is
257 // undefined). User operation might be not the actual user of the value, but
258 // the operation in the block that has a "real user" in one of the attached
259 // regions.
260 llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
261
262 for (Operation *user : value.getUsers()) {
263 Block *userBlock = user->getBlock();
264 Block *ancestor = definingRegion->findAncestorBlockInRegion(block&: *userBlock);
265 usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(op&: *user);
266 assert(ancestor && "ancestor block must be not null");
267 assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
268 }
269
270 // Find blocks where the `value` dies: the value is in `liveIn` set and not
271 // in the `liveOut` set. We place `drop_ref` immediately after the last use
272 // of the `value` in such regions (after handling few special cases).
273 //
274 // We do not traverse all the blocks in the `definingRegion`, because the
275 // `value` can be in the live in set only if it has users in the block, or it
276 // is defined in the block.
277 //
278 // Values with zero users (only definition) handled explicitly above.
279 for (auto &blockAndUser : usersInTheBlocks) {
280 Block *block = blockAndUser.getFirst();
281 Operation *userInTheBlock = blockAndUser.getSecond();
282
283 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
284
285 // Value must be in the live input set or defined in the block.
286 assert(blockLiveness->isLiveIn(value) ||
287 blockLiveness->getBlock() == value.getParentBlock());
288
289 // If value is in the live out set, it means it doesn't "die" in the block.
290 if (blockLiveness->isLiveOut(value))
291 continue;
292
293 // At this point we proved that `value` dies in the `block`. Find the last
294 // use of the `value` inside the `block`, this is where it "dies".
295 Operation *lastUser = blockLiveness->getEndOperation(value, startOperation: userInTheBlock);
296 assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
297 lastUsers.insert(Ptr: lastUser);
298 }
299
300 // Process all the last users of the `value` inside each block where the value
301 // dies.
302 for (Operation *lastUser : lastUsers) {
303 // Return like operations forward reference count.
304 if (lastUser->hasTrait<OpTrait::ReturnLike>())
305 continue;
306
307 // We can't currently handle other types of terminators.
308 if (lastUser->hasTrait<OpTrait::IsTerminator>())
309 return lastUser->emitError() << "async reference counting can't handle "
310 "terminators that are not ReturnLike";
311
312 // Add a drop_ref immediately after the last user.
313 builder.setInsertionPointAfter(lastUser);
314 builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
315 }
316
317 return success();
318}
319
320LogicalResult
321AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
322 OpBuilder builder(value.getContext());
323 Location loc = value.getLoc();
324
325 for (Operation *user : value.getUsers()) {
326 if (!isa<func::CallOp>(user))
327 continue;
328
329 // Add a reference before the function call to pass the value at `+1`
330 // reference to the function entry block.
331 builder.setInsertionPoint(user);
332 builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
333 }
334
335 return success();
336}
337
338LogicalResult
339AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
340 Value value) {
341 using BlockSet = llvm::SmallPtrSet<Block *, 4>;
342
343 OpBuilder builder(value.getContext());
344
345 // If a block has successors with different `liveIn` property of the `value`,
346 // record block successors that do not thave the `value` in the `liveIn` set.
347 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
348
349 // Use liveness analysis to find the placement of `drop_ref`operation.
350 auto &liveness = getAnalysis<Liveness>();
351
352 // Because we only add `drop_ref` operations to the region that defines the
353 // `value` we can only process CFG for the same region.
354 Region *definingRegion = value.getParentRegion();
355
356 // Collect blocks with successors with mismatching `liveIn` sets.
357 for (Block &block : definingRegion->getBlocks()) {
358 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
359
360 // Skip the block if value is not in the `liveOut` set.
361 if (!blockLiveness || !blockLiveness->isLiveOut(value))
362 continue;
363
364 BlockSet liveInSuccessors; // `value` is in `liveIn` set
365 BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
366
367 // Collect successors that do not have `value` in the `liveIn` set.
368 for (Block *successor : block.getSuccessors()) {
369 const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
370 if (succLiveness && succLiveness->isLiveIn(value))
371 liveInSuccessors.insert(Ptr: successor);
372 else
373 noLiveInSuccessors.insert(Ptr: successor);
374 }
375
376 // Block has successors with different `liveIn` property of the `value`.
377 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
378 divergentLivenessBlocks.try_emplace(Key: &block, Args&: noLiveInSuccessors);
379 }
380
381 // Try to insert `dropRef` operations to handle blocks with divergent liveness
382 // in successors blocks.
383 for (auto kv : divergentLivenessBlocks) {
384 Block *block = kv.getFirst();
385 BlockSet &successors = kv.getSecond();
386
387 // Coroutine suspension is a special case terminator for wich we do not
388 // need to create additional reference counting (see details above).
389 Operation *terminator = block->getTerminator();
390 if (isa<CoroSuspendOp>(terminator))
391 continue;
392
393 // We only support successor blocks with empty block argument list.
394 auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
395 if (llvm::any_of(Range&: successors, P: hasArgs))
396 return terminator->emitOpError()
397 << "successor have different `liveIn` property of the reference "
398 "counted value";
399
400 // Make sure that `dropRef` operation is called when branched into the
401 // successor block without `value` in the `liveIn` set.
402 for (Block *successor : successors) {
403 // If successor has a unique predecessor, it is safe to create `dropRef`
404 // operations directly in the successor block.
405 //
406 // Otherwise we need to create a special block for reference counting
407 // operations, and branch from it to the original successor block.
408 Block *refCountingBlock = nullptr;
409
410 if (successor->getUniquePredecessor() == block) {
411 refCountingBlock = successor;
412 } else {
413 refCountingBlock = &successor->getParent()->emplaceBlock();
414 refCountingBlock->moveBefore(block: successor);
415 OpBuilder builder = OpBuilder::atBlockEnd(block: refCountingBlock);
416 builder.create<cf::BranchOp>(value.getLoc(), successor);
417 }
418
419 OpBuilder builder = OpBuilder::atBlockBegin(block: refCountingBlock);
420 builder.create<RuntimeDropRefOp>(value.getLoc(), value,
421 builder.getI64IntegerAttr(1));
422
423 // No need to update the terminator operation.
424 if (successor == refCountingBlock)
425 continue;
426
427 // Update terminator `successor` block to `refCountingBlock`.
428 for (const auto &pair : llvm::enumerate(First: terminator->getSuccessors()))
429 if (pair.value() == successor)
430 terminator->setSuccessor(block: refCountingBlock, index: pair.index());
431 }
432 }
433
434 return success();
435}
436
437LogicalResult
438AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
439 // Short-circuit reference counting for values without uses.
440 if (succeeded(result: dropRefIfNoUses(value)))
441 return success();
442
443 // Add `drop_ref` operations based on the liveness analysis.
444 if (failed(result: addDropRefAfterLastUse(value)))
445 return failure();
446
447 // Add `add_ref` operations before function calls.
448 if (failed(result: addAddRefBeforeFunctionCall(value)))
449 return failure();
450
451 // Add `drop_ref` operations to successors with divergent `value` liveness.
452 if (failed(result: addDropRefInDivergentLivenessSuccessor(value)))
453 return failure();
454
455 return success();
456}
457
458void AsyncRuntimeRefCountingPass::runOnOperation() {
459 auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
460 if (failed(walkReferenceCountedValues(getOperation(), functor)))
461 signalPassFailure();
462}
463
464//===----------------------------------------------------------------------===//
465// Reference counting based on the user defined policy.
466//===----------------------------------------------------------------------===//
467
468namespace {
469
470class AsyncRuntimePolicyBasedRefCountingPass
471 : public impl::AsyncRuntimePolicyBasedRefCountingBase<
472 AsyncRuntimePolicyBasedRefCountingPass> {
473public:
474 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
475
476 void runOnOperation() override;
477
478private:
479 // Adds a reference counting operations for all uses of the `value` according
480 // to the reference counting policy.
481 LogicalResult addRefCounting(Value value);
482
483 void initializeDefaultPolicy();
484
485 llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
486};
487
488} // namespace
489
490LogicalResult
491AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
492 // Short-circuit reference counting for values without uses.
493 if (succeeded(result: dropRefIfNoUses(value)))
494 return success();
495
496 OpBuilder b(value.getContext());
497
498 // Consult the user defined policy for every value use.
499 for (OpOperand &operand : value.getUses()) {
500 Location loc = operand.getOwner()->getLoc();
501
502 for (auto &func : policy) {
503 FailureOr<int> refCount = func(operand);
504 if (failed(result: refCount))
505 return failure();
506
507 int cnt = *refCount;
508
509 // Create `add_ref` operation before the operand owner.
510 if (cnt > 0) {
511 b.setInsertionPoint(operand.getOwner());
512 b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
513 }
514
515 // Create `drop_ref` operation after the operand owner.
516 if (cnt < 0) {
517 b.setInsertionPointAfter(operand.getOwner());
518 b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
519 }
520 }
521 }
522
523 return success();
524}
525
526void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
527 policy.push_back(Elt: [](OpOperand &operand) -> FailureOr<int> {
528 Operation *op = operand.getOwner();
529 Type type = operand.get().getType();
530
531 bool isToken = isa<TokenType>(type);
532 bool isGroup = isa<GroupType>(type);
533 bool isValue = isa<ValueType>(type);
534
535 // Drop reference after async token or group error check (coro await).
536 if (dyn_cast<RuntimeIsErrorOp>(op))
537 return (isToken || isGroup) ? -1 : 0;
538
539 // Drop reference after async value load.
540 if (dyn_cast<RuntimeLoadOp>(op))
541 return isValue ? -1 : 0;
542
543 // Drop reference after async token added to the group.
544 if (dyn_cast<RuntimeAddToGroupOp>(op))
545 return isToken ? -1 : 0;
546
547 return 0;
548 });
549}
550
551void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
552 auto functor = [&](Value value) { return addRefCounting(value); };
553 if (failed(walkReferenceCountedValues(getOperation(), functor)))
554 signalPassFailure();
555}
556
557//----------------------------------------------------------------------------//
558
559std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
560 return std::make_unique<AsyncRuntimeRefCountingPass>();
561}
562
563std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
564 return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
565}
566

// Source: mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp