1//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Optimize Async dialect reference counting operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Async/Passes.h"
14
15#include "mlir/Dialect/Async/IR/Async.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "llvm/ADT/SmallSet.h"
18#include "llvm/Support/Debug.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
22#include "mlir/Dialect/Async/Passes.h.inc"
23} // namespace mlir
24
25#define DEBUG_TYPE "async-ref-counting"
26
27using namespace mlir;
28using namespace mlir::async;
29
30namespace {
31
32class AsyncRuntimeRefCountingOptPass
33 : public impl::AsyncRuntimeRefCountingOptBase<
34 AsyncRuntimeRefCountingOptPass> {
35public:
36 AsyncRuntimeRefCountingOptPass() = default;
37 void runOnOperation() override;
38
39private:
40 LogicalResult optimizeReferenceCounting(
41 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
42};
43
44} // namespace
45
46LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
48 Region *definingRegion = value.getParentRegion();
49
50 // Find all users of the `value` inside each block, including operations that
51 // do not use `value` directly, but have a direct use inside nested region(s).
52 //
53 // Example:
54 //
55 // ^bb1:
56 // %token = ...
57 // scf.if %cond {
58 // ^bb2:
59 // async.runtime.await %token : !async.token
60 // }
61 //
62 // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
63 // (`scf.if`).
64
65 struct BlockUsersInfo {
66 llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
67 llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
68 llvm::SmallVector<Operation *, 4> users;
69 };
70
71 llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72
73 auto updateBlockUsersInfo = [&](Operation *user) {
74 BlockUsersInfo &info = blockUsers[user->getBlock()];
75 info.users.push_back(Elt: user);
76
77 if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78 info.addRefs.push_back(addRef);
79 if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80 info.dropRefs.push_back(dropRef);
81 };
82
83 for (Operation *user : value.getUsers()) {
84 while (user->getParentRegion() != definingRegion) {
85 updateBlockUsersInfo(user);
86 user = user->getParentOp();
87 assert(user != nullptr && "value user lies outside of the value region");
88 }
89
90 updateBlockUsersInfo(user);
91 }
92
93 // Sort all operations found in the block.
94 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
95 auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
96 return a->isBeforeInBlock(other: b);
97 };
98 llvm::sort(info.addRefs, isBeforeInBlock);
99 llvm::sort(info.dropRefs, isBeforeInBlock);
100 llvm::sort(C&: info.users, Comp: [&](Operation *a, Operation *b) -> bool {
101 return isBeforeInBlock(a, b);
102 });
103
104 return info;
105 };
106
107 // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
108 // blocks that modify the reference count of the `value`.
109 for (auto &kv : blockUsers) {
110 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111
112 for (RuntimeAddRefOp addRef : info.addRefs) {
113 for (RuntimeDropRefOp dropRef : info.dropRefs) {
114 // `drop_ref` operation after the `add_ref` with matching count.
115 if (dropRef.getCount() != addRef.getCount() ||
116 dropRef->isBeforeInBlock(addRef.getOperation()))
117 continue;
118
119 // When reference counted value passed to a function as an argument,
120 // function takes ownership of +1 reference and it will drop it before
121 // returning.
122 //
123 // Example:
124 //
125 // %token = ... : !async.token
126 //
127 // async.runtime.add_ref %token {count = 1 : i64} : !async.token
128 // call @pass_token(%token: !async.token, ...)
129 //
130 // async.await %token : !async.token
131 // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
132 //
133 // In this example if we'll cancel a pair of reference counting
134 // operations we might end up with a deallocated token when we'll
135 // reach `async.await` operation.
136 Operation *firstFunctionCallUser = nullptr;
137 Operation *lastNonFunctionCallUser = nullptr;
138
139 for (Operation *user : info.users) {
140 // `user` operation lies after `addRef` ...
141 if (user == addRef || user->isBeforeInBlock(addRef))
142 continue;
143 // ... and before `dropRef`.
144 if (user == dropRef || dropRef->isBeforeInBlock(user))
145 break;
146
147 // Find the first function call user of the reference counted value.
148 Operation *functionCall = dyn_cast<func::CallOp>(user);
149 if (functionCall &&
150 (!firstFunctionCallUser ||
151 functionCall->isBeforeInBlock(firstFunctionCallUser))) {
152 firstFunctionCallUser = functionCall;
153 continue;
154 }
155
156 // Find the last regular user of the reference counted value.
157 if (!functionCall &&
158 (!lastNonFunctionCallUser ||
159 lastNonFunctionCallUser->isBeforeInBlock(user))) {
160 lastNonFunctionCallUser = user;
161 continue;
162 }
163 }
164
165 // Non function call user after the function call user of the reference
166 // counted value.
167 if (firstFunctionCallUser && lastNonFunctionCallUser &&
168 firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
169 continue;
170
171 // Try to cancel the pair of `add_ref` and `drop_ref` operations.
172 auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173 addRef.getOperation());
174
175 if (!emplaced.second) // `drop_ref` was already marked for removal
176 continue; // go to the next `drop_ref`
177
178 if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
179 break; // go to the next `add_ref`
180 }
181 }
182 }
183
184 return success();
185}
186
187void AsyncRuntimeRefCountingOptPass::runOnOperation() {
188 Operation *op = getOperation();
189
190 // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
191 //
192 // Find all cancellable pairs of operation and erase them in the end to keep
193 // all iterators valid while we are walking the function operations.
194 llvm::SmallDenseMap<Operation *, Operation *> cancellable;
195
196 // Optimize reference counting for values defined by block arguments.
197 WalkResult blockWalk = op->walk(callback: [&](Block *block) -> WalkResult {
198 for (BlockArgument arg : block->getArguments())
199 if (isRefCounted(type: arg.getType()))
200 if (failed(optimizeReferenceCounting(arg, cancellable)))
201 return WalkResult::interrupt();
202
203 return WalkResult::advance();
204 });
205
206 if (blockWalk.wasInterrupted())
207 signalPassFailure();
208
209 // Optimize reference counting for values defined by operation results.
210 WalkResult opWalk = op->walk(callback: [&](Operation *op) -> WalkResult {
211 for (unsigned i = 0; i < op->getNumResults(); ++i)
212 if (isRefCounted(type: op->getResultTypes()[i]))
213 if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
214 return WalkResult::interrupt();
215
216 return WalkResult::advance();
217 });
218
219 if (opWalk.wasInterrupted())
220 signalPassFailure();
221
222 LLVM_DEBUG({
223 llvm::dbgs() << "Found " << cancellable.size()
224 << " cancellable reference counting operations\n";
225 });
226
227 // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
228 for (auto &kv : cancellable) {
229 kv.first->erase();
230 kv.second->erase();
231 }
232}
233
234std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
235 return std::make_unique<AsyncRuntimeRefCountingOptPass>();
236}
237

source code of mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp