//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations: find and erase
// matching pairs of `async.runtime.add_ref` and `async.runtime.drop_ref`
// operations that provably cancel each other out.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "async-ref-counting"

using namespace mlir;
using namespace mlir::async;

namespace {

class AsyncRuntimeRefCountingOptPass
    : public impl::AsyncRuntimeRefCountingOptBase<
          AsyncRuntimeRefCountingOptPass> {
public:
  AsyncRuntimeRefCountingOptPass() = default;
  void runOnOperation() override;

private:
  LogicalResult optimizeReferenceCounting(
      Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
};

} // namespace

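/// Collects all `async.runtime.add_ref` / `async.runtime.drop_ref` users of
/// `value`, finds pairs that provably cancel each other out, and records them
/// in `cancellable` (keyed by the `drop_ref` operation, mapped to its matching
/// `add_ref`). The actual erasure is deferred to the caller.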
LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations that
  // do not use `value` directly but contain a direct use inside their nested
  // region(s).
  //
  // Example:
  //
  //   ^bb1:
  //     %token = ...
  //     scf.if %cond {
  //       ^bb2:
  //         async.runtime.await %token : !async.token
  //     }
  //
  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
  // (`scf.if`).

  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](Operation *user) {
    BlockUsersInfo &info = blockUsers[user->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
      info.dropRefs.push_back(dropRef);
  };

  for (Operation *user : value.getUsers()) {
    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo(user);
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
    }

    updateBlockUsersInfo(user);
  }
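
  // Note: the loop above records a nested user in its own block, and then
  // records each ancestor operation in its block, up to the ancestor that
  // lives directly in `definingRegion`. In the example above,
  // `async.runtime.await` is recorded in ^bb2 and the enclosing `scf.if` in
  // ^bb1.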

  // Sort all operations found in the block by their position, so that the
  // pairing loop below can process `add_ref` and `drop_ref` operations in
  // block order.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
      return isBeforeInBlock(a, b);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
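  //
  // Illustrative example (hypothetical IR, not taken from a test): with no
  // function calls between the pair, the pair is cancellable:
  //
  //   %token = ... : !async.token
  //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
  //   async.runtime.await %token : !async.token
  //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
  //
  // No function call between the two operations can take ownership of a
  // +1 reference and drop it early, so erasing both does not change when
  // `%token` is deallocated.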
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    for (RuntimeAddRefOp addRef : info.addRefs) {
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        // Only consider a `drop_ref` operation that comes after the `add_ref`
        // and has a matching count.
        if (dropRef.getCount() != addRef.getCount() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // When a reference counted value is passed to a function as an
        // argument, the function takes ownership of a +1 reference and drops
        // it before returning.
        //
        // Example:
        //
        //   %token = ... : !async.token
        //
        //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
        //   call @pass_token(%token: !async.token, ...)
        //
        //   async.await %token : !async.token
        //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
        //
        // In this example, cancelling the pair of reference counting
        // operations might leave the token deallocated by the time we reach
        // the `async.await` operation.
        Operation *firstFunctionCallUser = nullptr;
        Operation *lastNonFunctionCallUser = nullptr;

        for (Operation *user : info.users) {
          // `user` operation lies after `addRef` ...
          if (user == addRef || user->isBeforeInBlock(addRef))
            continue;
          // ... and before `dropRef`.
          if (user == dropRef || dropRef->isBeforeInBlock(user))
            break;

          // Find the first function call user of the reference counted value.
          Operation *functionCall = dyn_cast<func::CallOp>(user);
          if (functionCall &&
              (!firstFunctionCallUser ||
               functionCall->isBeforeInBlock(firstFunctionCallUser))) {
            firstFunctionCallUser = functionCall;
            continue;
          }

          // Find the last regular (non function call) user of the reference
          // counted value.
          if (!functionCall &&
              (!lastNonFunctionCallUser ||
               lastNonFunctionCallUser->isBeforeInBlock(user))) {
            lastNonFunctionCallUser = user;
            continue;
          }
        }

        // Skip the pair if a regular user appears after a function call user:
        // the callee may drop its +1 reference before that user executes.
        if (firstFunctionCallUser && lastNonFunctionCallUser &&
            firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
          continue;

        // Try to cancel the pair of `add_ref` and `drop_ref` operations.
        auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
                                                addRef.getOperation());

        if (!emplaced.second) // `drop_ref` was already marked for removal
          continue;           // go to the next `drop_ref`

        // Successfully cancelled `add_ref` <-> `drop_ref`: go to the next
        // `add_ref`.
        break;
      }
    }
  }

  return success();
}

void AsyncRuntimeRefCountingOptPass::runOnOperation() {
  Operation *op = getOperation();

  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
  //
  // Find all cancellable pairs of operations and erase them at the end to
  // keep all iterators valid while walking the function operations.
  llvm::SmallDenseMap<Operation *, Operation *> cancellable;

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg, cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      if (isRefCounted(op->getResultTypes()[i]))
        if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();

  LLVM_DEBUG({
    llvm::dbgs() << "Found " << cancellable.size()
                 << " cancellable reference counting operations\n";
  });

  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
  for (auto &kv : cancellable) {
    kv.first->erase();
    kv.second->erase();
  }
}

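// A minimal usage sketch (illustrative; `context` and `module` are assumed to
// be a live MLIRContext and ModuleOp in the caller's code):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createAsyncRuntimeRefCountingOptPass());
//   if (failed(pm.run(module)))
//     ... // handle pass failure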
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}