1 | //===- Inliner.cpp ---- SCC-based inliner ---------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements the Inliner, which uses a basic inlining
// algorithm that operates bottom up over the Strongly Connected Components
// (SCCs) of the CallGraph. This enables a more incremental propagation of
// inlining decisions from the leaves to the roots of the callgraph.
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Transforms/Inliner.h" |
17 | #include "mlir/IR/Threading.h" |
18 | #include "mlir/Interfaces/CallInterfaces.h" |
19 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Support/DebugStringHelper.h" |
22 | #include "mlir/Transforms/InliningUtils.h" |
23 | #include "llvm/ADT/SCCIterator.h" |
24 | #include "llvm/ADT/STLExtras.h" |
25 | #include "llvm/ADT/SmallPtrSet.h" |
26 | #include "llvm/Support/Debug.h" |
27 | |
28 | #define DEBUG_TYPE "inlining" |
29 | |
30 | using namespace mlir; |
31 | |
32 | using ResolvedCall = Inliner::ResolvedCall; |
33 | |
34 | //===----------------------------------------------------------------------===// |
35 | // Symbol Use Tracking |
36 | //===----------------------------------------------------------------------===// |
37 | |
38 | /// Walk all of the used symbol callgraph nodes referenced with the given op. |
39 | static void walkReferencedSymbolNodes( |
40 | Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, |
41 | DenseMap<Attribute, CallGraphNode *> &resolvedRefs, |
42 | function_ref<void(CallGraphNode *, Operation *)> callback) { |
43 | auto symbolUses = SymbolTable::getSymbolUses(from: op); |
44 | assert(symbolUses && "expected uses to be valid" ); |
45 | |
46 | Operation *symbolTableOp = op->getParentOp(); |
47 | for (const SymbolTable::SymbolUse &use : *symbolUses) { |
48 | auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr}); |
49 | CallGraphNode *&node = refIt.first->second; |
50 | |
51 | // If this is the first instance of this reference, try to resolve a |
52 | // callgraph node for it. |
53 | if (refIt.second) { |
54 | auto *symbolOp = symbolTable.lookupNearestSymbolFrom(from: symbolTableOp, |
55 | symbol: use.getSymbolRef()); |
56 | auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp); |
57 | if (!callableOp) |
58 | continue; |
59 | node = cg.lookupNode(region: callableOp.getCallableRegion()); |
60 | } |
61 | if (node) |
62 | callback(node, use.getUser()); |
63 | } |
64 | } |
65 | |
//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//
68 | |
69 | namespace { |
70 | /// This struct tracks the uses of callgraph nodes that can be dropped when |
71 | /// use_empty. It directly tracks and manages a use-list for all of the |
72 | /// call-graph nodes. This is necessary because many callgraph nodes are |
73 | /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use` |
74 | /// class. |
75 | struct CGUseList { |
76 | /// This struct tracks the uses of callgraph nodes within a specific |
77 | /// operation. |
78 | struct CGUser { |
79 | /// Any nodes referenced in the top-level attribute list of this user. We |
80 | /// use a set here because the number of references does not matter. |
81 | DenseSet<CallGraphNode *> topLevelUses; |
82 | |
83 | /// Uses of nodes referenced by nested operations. |
84 | DenseMap<CallGraphNode *, int> innerUses; |
85 | }; |
86 | |
87 | CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable); |
88 | |
89 | /// Drop uses of nodes referred to by the given call operation that resides |
90 | /// within 'userNode'. |
91 | void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg); |
92 | |
93 | /// Remove the given node from the use list. |
94 | void eraseNode(CallGraphNode *node); |
95 | |
96 | /// Returns true if the given callgraph node has no uses and can be pruned. |
97 | bool isDead(CallGraphNode *node) const; |
98 | |
99 | /// Returns true if the given callgraph node has a single use and can be |
100 | /// discarded. |
101 | bool hasOneUseAndDiscardable(CallGraphNode *node) const; |
102 | |
103 | /// Recompute the uses held by the given callgraph node. |
104 | void recomputeUses(CallGraphNode *node, CallGraph &cg); |
105 | |
106 | /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy |
107 | /// of 'lhs' into 'rhs'. |
108 | void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs); |
109 | |
110 | private: |
111 | /// Decrement the uses of discardable nodes referenced by the given user. |
112 | void decrementDiscardableUses(CGUser &uses); |
113 | |
114 | /// A mapping between a discardable callgraph node (that is a symbol) and the |
115 | /// number of uses for this node. |
116 | DenseMap<CallGraphNode *, int> discardableSymNodeUses; |
117 | |
118 | /// A mapping between a callgraph node and the symbol callgraph nodes that it |
119 | /// uses. |
120 | DenseMap<CallGraphNode *, CGUser> nodeUses; |
121 | |
122 | /// A symbol table to use when resolving call lookups. |
123 | SymbolTableCollection &symbolTable; |
124 | }; |
125 | } // namespace |
126 | |
127 | CGUseList::CGUseList(Operation *op, CallGraph &cg, |
128 | SymbolTableCollection &symbolTable) |
129 | : symbolTable(symbolTable) { |
130 | /// A set of callgraph nodes that are always known to be live during inlining. |
131 | DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes; |
132 | |
133 | // Walk each of the symbol tables looking for discardable callgraph nodes. |
134 | auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { |
135 | for (Operation &op : symbolTableOp->getRegion(index: 0).getOps()) { |
136 | // If this is a callgraph operation, check to see if it is discardable. |
137 | if (auto callable = dyn_cast<CallableOpInterface>(&op)) { |
138 | if (auto *node = cg.lookupNode(callable.getCallableRegion())) { |
139 | SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op); |
140 | if (symbol && (allUsesVisible || symbol.isPrivate()) && |
141 | symbol.canDiscardOnUseEmpty()) { |
142 | discardableSymNodeUses.try_emplace(node, 0); |
143 | } |
144 | continue; |
145 | } |
146 | } |
147 | // Otherwise, check for any referenced nodes. These will be always-live. |
148 | walkReferencedSymbolNodes(op: &op, cg, symbolTable, resolvedRefs&: alwaysLiveNodes, |
149 | callback: [](CallGraphNode *, Operation *) {}); |
150 | } |
151 | }; |
152 | SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), |
153 | callback: walkFn); |
154 | |
155 | // Drop the use information for any discardable nodes that are always live. |
156 | for (auto &it : alwaysLiveNodes) |
157 | discardableSymNodeUses.erase(Val: it.second); |
158 | |
159 | // Compute the uses for each of the callable nodes in the graph. |
160 | for (CallGraphNode *node : cg) |
161 | recomputeUses(node, cg); |
162 | } |
163 | |
164 | void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp, |
165 | CallGraph &cg) { |
166 | auto &userRefs = nodeUses[userNode].innerUses; |
167 | auto walkFn = [&](CallGraphNode *node, Operation *user) { |
168 | auto parentIt = userRefs.find(Val: node); |
169 | if (parentIt == userRefs.end()) |
170 | return; |
171 | --parentIt->second; |
172 | --discardableSymNodeUses[node]; |
173 | }; |
174 | DenseMap<Attribute, CallGraphNode *> resolvedRefs; |
175 | walkReferencedSymbolNodes(op: callOp, cg, symbolTable, resolvedRefs, callback: walkFn); |
176 | } |
177 | |
178 | void CGUseList::eraseNode(CallGraphNode *node) { |
179 | // Drop all child nodes. |
180 | for (auto &edge : *node) |
181 | if (edge.isChild()) |
182 | eraseNode(node: edge.getTarget()); |
183 | |
184 | // Drop the uses held by this node and erase it. |
185 | auto useIt = nodeUses.find(Val: node); |
186 | assert(useIt != nodeUses.end() && "expected node to be valid" ); |
187 | decrementDiscardableUses(uses&: useIt->getSecond()); |
188 | nodeUses.erase(I: useIt); |
189 | discardableSymNodeUses.erase(Val: node); |
190 | } |
191 | |
192 | bool CGUseList::isDead(CallGraphNode *node) const { |
193 | // If the parent operation isn't a symbol, simply check normal SSA deadness. |
194 | Operation *nodeOp = node->getCallableRegion()->getParentOp(); |
195 | if (!isa<SymbolOpInterface>(nodeOp)) |
196 | return isMemoryEffectFree(op: nodeOp) && nodeOp->use_empty(); |
197 | |
198 | // Otherwise, check the number of symbol uses. |
199 | auto symbolIt = discardableSymNodeUses.find(Val: node); |
200 | return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0; |
201 | } |
202 | |
203 | bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const { |
204 | // If this isn't a symbol node, check for side-effects and SSA use count. |
205 | Operation *nodeOp = node->getCallableRegion()->getParentOp(); |
206 | if (!isa<SymbolOpInterface>(nodeOp)) |
207 | return isMemoryEffectFree(op: nodeOp) && nodeOp->hasOneUse(); |
208 | |
209 | // Otherwise, check the number of symbol uses. |
210 | auto symbolIt = discardableSymNodeUses.find(Val: node); |
211 | return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1; |
212 | } |
213 | |
214 | void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) { |
215 | Operation *parentOp = node->getCallableRegion()->getParentOp(); |
216 | CGUser &uses = nodeUses[node]; |
217 | decrementDiscardableUses(uses); |
218 | |
219 | // Collect the new discardable uses within this node. |
220 | uses = CGUser(); |
221 | DenseMap<Attribute, CallGraphNode *> resolvedRefs; |
222 | auto walkFn = [&](CallGraphNode *refNode, Operation *user) { |
223 | auto discardSymIt = discardableSymNodeUses.find(Val: refNode); |
224 | if (discardSymIt == discardableSymNodeUses.end()) |
225 | return; |
226 | |
227 | if (user != parentOp) |
228 | ++uses.innerUses[refNode]; |
229 | else if (!uses.topLevelUses.insert(V: refNode).second) |
230 | return; |
231 | ++discardSymIt->second; |
232 | }; |
233 | walkReferencedSymbolNodes(op: parentOp, cg, symbolTable, resolvedRefs, callback: walkFn); |
234 | } |
235 | |
236 | void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) { |
237 | auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs]; |
238 | for (auto &useIt : lhsUses.innerUses) { |
239 | rhsUses.innerUses[useIt.first] += useIt.second; |
240 | discardableSymNodeUses[useIt.first] += useIt.second; |
241 | } |
242 | } |
243 | |
244 | void CGUseList::decrementDiscardableUses(CGUser &uses) { |
245 | for (CallGraphNode *node : uses.topLevelUses) |
246 | --discardableSymNodeUses[node]; |
247 | for (auto &it : uses.innerUses) |
248 | discardableSymNodeUses[it.first] -= it.second; |
249 | } |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // CallGraph traversal |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | namespace { |
256 | /// This class represents a specific callgraph SCC. |
257 | class CallGraphSCC { |
258 | public: |
259 | CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator) |
260 | : parentIterator(parentIterator) {} |
261 | /// Return a range over the nodes within this SCC. |
262 | std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); } |
263 | std::vector<CallGraphNode *>::iterator end() { return nodes.end(); } |
264 | |
265 | /// Reset the nodes of this SCC with those provided. |
266 | void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; } |
267 | |
268 | /// Remove the given node from this SCC. |
269 | void remove(CallGraphNode *node) { |
270 | auto it = llvm::find(Range&: nodes, Val: node); |
271 | if (it != nodes.end()) { |
272 | nodes.erase(position: it); |
273 | parentIterator.ReplaceNode(Old: node, New: nullptr); |
274 | } |
275 | } |
276 | |
277 | private: |
278 | std::vector<CallGraphNode *> nodes; |
279 | llvm::scc_iterator<const CallGraph *> &parentIterator; |
280 | }; |
281 | } // namespace |
282 | |
283 | /// Run a given transformation over the SCCs of the callgraph in a bottom up |
284 | /// traversal. |
285 | static LogicalResult runTransformOnCGSCCs( |
286 | const CallGraph &cg, |
287 | function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) { |
288 | llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(G: &cg); |
289 | CallGraphSCC currentSCC(cgi); |
290 | while (!cgi.isAtEnd()) { |
291 | // Copy the current SCC and increment so that the transformer can modify the |
292 | // SCC without invalidating our iterator. |
293 | currentSCC.reset(newNodes: *cgi); |
294 | ++cgi; |
295 | if (failed(result: sccTransformer(currentSCC))) |
296 | return failure(); |
297 | } |
298 | return success(); |
299 | } |
300 | |
301 | /// Collect all of the callable operations within the given range of blocks. If |
302 | /// `traverseNestedCGNodes` is true, this will also collect call operations |
303 | /// inside of nested callgraph nodes. |
304 | static void collectCallOps(iterator_range<Region::iterator> blocks, |
305 | CallGraphNode *sourceNode, CallGraph &cg, |
306 | SymbolTableCollection &symbolTable, |
307 | SmallVectorImpl<ResolvedCall> &calls, |
308 | bool traverseNestedCGNodes) { |
309 | SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist; |
310 | auto addToWorklist = [&](CallGraphNode *node, |
311 | iterator_range<Region::iterator> blocks) { |
312 | for (Block &block : blocks) |
313 | worklist.emplace_back(Args: &block, Args&: node); |
314 | }; |
315 | |
316 | addToWorklist(sourceNode, blocks); |
317 | while (!worklist.empty()) { |
318 | Block *block; |
319 | std::tie(args&: block, args&: sourceNode) = worklist.pop_back_val(); |
320 | |
321 | for (Operation &op : *block) { |
322 | if (auto call = dyn_cast<CallOpInterface>(op)) { |
323 | // TODO: Support inlining nested call references. |
324 | CallInterfaceCallable callable = call.getCallableForCallee(); |
325 | if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) { |
326 | if (!isa<FlatSymbolRefAttr>(symRef)) |
327 | continue; |
328 | } |
329 | |
330 | CallGraphNode *targetNode = cg.resolveCallable(call: call, symbolTable); |
331 | if (!targetNode->isExternal()) |
332 | calls.emplace_back(call, sourceNode, targetNode); |
333 | continue; |
334 | } |
335 | |
336 | // If this is not a call, traverse the nested regions. If |
337 | // `traverseNestedCGNodes` is false, then don't traverse nested call graph |
338 | // regions. |
339 | for (auto &nestedRegion : op.getRegions()) { |
340 | CallGraphNode *nestedNode = cg.lookupNode(region: &nestedRegion); |
341 | if (traverseNestedCGNodes || !nestedNode) |
342 | addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion); |
343 | } |
344 | } |
345 | } |
346 | } |
347 | |
348 | //===----------------------------------------------------------------------===// |
349 | // InlinerInterfaceImpl |
350 | //===----------------------------------------------------------------------===// |
351 | |
352 | #ifndef NDEBUG |
353 | static std::string getNodeName(CallOpInterface op) { |
354 | if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) |
355 | return debugString(op); |
356 | return "_unnamed_callee_" ; |
357 | } |
358 | #endif |
359 | |
360 | /// Return true if the specified `inlineHistoryID` indicates an inline history |
361 | /// that already includes `node`. |
362 | static bool inlineHistoryIncludes( |
363 | CallGraphNode *node, std::optional<size_t> inlineHistoryID, |
364 | MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>> |
365 | inlineHistory) { |
366 | while (inlineHistoryID.has_value()) { |
367 | assert(*inlineHistoryID < inlineHistory.size() && |
368 | "Invalid inline history ID" ); |
369 | if (inlineHistory[*inlineHistoryID].first == node) |
370 | return true; |
371 | inlineHistoryID = inlineHistory[*inlineHistoryID].second; |
372 | } |
373 | return false; |
374 | } |
375 | |
376 | namespace { |
377 | /// This class provides a specialization of the main inlining interface. |
378 | struct InlinerInterfaceImpl : public InlinerInterface { |
379 | InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg, |
380 | SymbolTableCollection &symbolTable) |
381 | : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {} |
382 | |
383 | /// Process a set of blocks that have been inlined. This callback is invoked |
384 | /// *before* inlined terminator operations have been processed. |
385 | void |
386 | processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { |
387 | // Find the closest callgraph node from the first block. |
388 | CallGraphNode *node; |
389 | Region *region = inlinedBlocks.begin()->getParent(); |
390 | while (!(node = cg.lookupNode(region))) { |
391 | region = region->getParentRegion(); |
392 | assert(region && "expected valid parent node" ); |
393 | } |
394 | |
395 | collectCallOps(blocks: inlinedBlocks, sourceNode: node, cg, symbolTable, calls, |
396 | /*traverseNestedCGNodes=*/true); |
397 | } |
398 | |
399 | /// Mark the given callgraph node for deletion. |
400 | void markForDeletion(CallGraphNode *node) { deadNodes.insert(Ptr: node); } |
401 | |
402 | /// This method properly disposes of callables that became dead during |
403 | /// inlining. This should not be called while iterating over the SCCs. |
404 | void eraseDeadCallables() { |
405 | for (CallGraphNode *node : deadNodes) |
406 | node->getCallableRegion()->getParentOp()->erase(); |
407 | } |
408 | |
409 | /// The set of callables known to be dead. |
410 | SmallPtrSet<CallGraphNode *, 8> deadNodes; |
411 | |
412 | /// The current set of call instructions to consider for inlining. |
413 | SmallVector<ResolvedCall, 8> calls; |
414 | |
415 | /// The callgraph being operated on. |
416 | CallGraph &cg; |
417 | |
418 | /// A symbol table to use when resolving call lookups. |
419 | SymbolTableCollection &symbolTable; |
420 | }; |
421 | } // namespace |
422 | |
423 | namespace mlir { |
424 | |
425 | class Inliner::Impl { |
426 | public: |
427 | Impl(Inliner &inliner) : inliner(inliner) {} |
428 | |
429 | /// Attempt to inline calls within the given scc, and run simplifications, |
430 | /// until a fixed point is reached. This allows for the inlining of newly |
431 | /// devirtualized calls. Returns failure if there was a fatal error during |
432 | /// inlining. |
433 | LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, |
434 | CGUseList &useList, CallGraphSCC ¤tSCC, |
435 | MLIRContext *context); |
436 | |
437 | private: |
438 | /// Optimize the nodes within the given SCC with one of the held optimization |
439 | /// pass pipelines. Returns failure if an error occurred during the |
440 | /// optimization of the SCC, success otherwise. |
441 | LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList, |
442 | CallGraphSCC ¤tSCC, MLIRContext *context); |
443 | |
444 | /// Optimize the nodes within the given SCC in parallel. Returns failure if an |
445 | /// error occurred during the optimization of the SCC, success otherwise. |
446 | LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, |
447 | MLIRContext *context); |
448 | |
449 | /// Optimize the given callable node with one of the pass managers provided |
450 | /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if |
451 | /// an error occurred during the optimization of the callable, success |
452 | /// otherwise. |
453 | LogicalResult optimizeCallable(CallGraphNode *node, |
454 | llvm::StringMap<OpPassManager> &pipelines); |
455 | |
456 | /// Attempt to inline calls within the given scc. This function returns |
457 | /// success if any calls were inlined, failure otherwise. |
458 | LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, |
459 | CGUseList &useList, CallGraphSCC ¤tSCC); |
460 | |
461 | /// Returns true if the given call should be inlined. |
462 | bool shouldInline(ResolvedCall &resolvedCall); |
463 | |
464 | private: |
465 | Inliner &inliner; |
466 | llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines; |
467 | }; |
468 | |
469 | LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface, |
470 | CGUseList &useList, |
471 | CallGraphSCC ¤tSCC, |
472 | MLIRContext *context) { |
473 | // Continuously simplify and inline until we either reach a fixed point, or |
474 | // hit the maximum iteration count. Simplifying early helps to refine the cost |
475 | // model, and in future iterations may devirtualize new calls. |
476 | unsigned iterationCount = 0; |
477 | do { |
478 | if (failed(result: optimizeSCC(cg&: inlinerIface.cg, useList, currentSCC, context))) |
479 | return failure(); |
480 | if (failed(result: inlineCallsInSCC(inlinerIface, useList, currentSCC))) |
481 | break; |
482 | } while (++iterationCount < inliner.config.getMaxInliningIterations()); |
483 | return success(); |
484 | } |
485 | |
486 | LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList, |
487 | CallGraphSCC ¤tSCC, |
488 | MLIRContext *context) { |
489 | // Collect the sets of nodes to simplify. |
490 | SmallVector<CallGraphNode *, 4> nodesToVisit; |
491 | for (auto *node : currentSCC) { |
492 | if (node->isExternal()) |
493 | continue; |
494 | |
495 | // Don't simplify nodes with children. Nodes with children require special |
496 | // handling as we may remove the node during simplification. In the future, |
497 | // we should be able to handle this case with proper node deletion tracking. |
498 | if (node->hasChildren()) |
499 | continue; |
500 | |
501 | // We also won't apply simplifications to nodes that can't have passes |
502 | // scheduled on them. |
503 | auto *region = node->getCallableRegion(); |
504 | if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
505 | continue; |
506 | nodesToVisit.push_back(Elt: node); |
507 | } |
508 | if (nodesToVisit.empty()) |
509 | return success(); |
510 | |
511 | // Optimize each of the nodes within the SCC in parallel. |
512 | if (failed(result: optimizeSCCAsync(nodesToVisit, context))) |
513 | return failure(); |
514 | |
515 | // Recompute the uses held by each of the nodes. |
516 | for (CallGraphNode *node : nodesToVisit) |
517 | useList.recomputeUses(node, cg); |
518 | return success(); |
519 | } |
520 | |
521 | LogicalResult |
522 | Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit, |
523 | MLIRContext *ctx) { |
524 | // We must maintain a fixed pool of pass managers which is at least as large |
525 | // as the maximum parallelism of the failableParallelForEach below. |
526 | // Note: The number of pass managers here needs to remain constant |
527 | // to prevent issues with pass instrumentations that rely on having the same |
528 | // pass manager for the main thread. |
529 | size_t numThreads = ctx->getNumThreads(); |
530 | const auto &opPipelines = inliner.config.getOpPipelines(); |
531 | if (pipelines.size() < numThreads) { |
532 | pipelines.reserve(N: numThreads); |
533 | pipelines.resize(N: numThreads, NV: opPipelines); |
534 | } |
535 | |
536 | // Ensure an analysis manager has been constructed for each of the nodes. |
537 | // This prevents thread races when running the nested pipelines. |
538 | for (CallGraphNode *node : nodesToVisit) |
539 | inliner.am.nest(op: node->getCallableRegion()->getParentOp()); |
540 | |
541 | // An atomic failure variable for the async executors. |
542 | std::vector<std::atomic<bool>> activePMs(pipelines.size()); |
543 | std::fill(first: activePMs.begin(), last: activePMs.end(), value: false); |
544 | return failableParallelForEach(context: ctx, range&: nodesToVisit, func: [&](CallGraphNode *node) { |
545 | // Find a pass manager for this operation. |
546 | auto it = llvm::find_if(Range&: activePMs, P: [](std::atomic<bool> &isActive) { |
547 | bool expectedInactive = false; |
548 | return isActive.compare_exchange_strong(i1&: expectedInactive, i2: true); |
549 | }); |
550 | assert(it != activePMs.end() && |
551 | "could not find inactive pass manager for thread" ); |
552 | unsigned pmIndex = it - activePMs.begin(); |
553 | |
554 | // Optimize this callable node. |
555 | LogicalResult result = optimizeCallable(node, pipelines&: pipelines[pmIndex]); |
556 | |
557 | // Reset the active bit for this pass manager. |
558 | activePMs[pmIndex].store(i: false); |
559 | return result; |
560 | }); |
561 | } |
562 | |
563 | LogicalResult |
564 | Inliner::Impl::optimizeCallable(CallGraphNode *node, |
565 | llvm::StringMap<OpPassManager> &pipelines) { |
566 | Operation *callable = node->getCallableRegion()->getParentOp(); |
567 | StringRef opName = callable->getName().getStringRef(); |
568 | auto pipelineIt = pipelines.find(Key: opName); |
569 | const auto &defaultPipeline = inliner.config.getDefaultPipeline(); |
570 | if (pipelineIt == pipelines.end()) { |
571 | // If a pipeline didn't exist, use the generic pipeline if possible. |
572 | if (!defaultPipeline) |
573 | return success(); |
574 | |
575 | OpPassManager defaultPM(opName); |
576 | defaultPipeline(defaultPM); |
577 | pipelineIt = pipelines.try_emplace(Key: opName, Args: std::move(defaultPM)).first; |
578 | } |
579 | return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable); |
580 | } |
581 | |
582 | /// Attempt to inline calls within the given scc. This function returns |
583 | /// success if any calls were inlined, failure otherwise. |
584 | LogicalResult |
585 | Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, |
586 | CGUseList &useList, CallGraphSCC ¤tSCC) { |
587 | CallGraph &cg = inlinerIface.cg; |
588 | auto &calls = inlinerIface.calls; |
589 | |
590 | // A set of dead nodes to remove after inlining. |
591 | llvm::SmallSetVector<CallGraphNode *, 1> deadNodes; |
592 | |
593 | // Collect all of the direct calls within the nodes of the current SCC. We |
594 | // don't traverse nested callgraph nodes, because they are handled separately |
595 | // likely within a different SCC. |
596 | for (CallGraphNode *node : currentSCC) { |
597 | if (node->isExternal()) |
598 | continue; |
599 | |
600 | // Don't collect calls if the node is already dead. |
601 | if (useList.isDead(node)) { |
602 | deadNodes.insert(X: node); |
603 | } else { |
604 | collectCallOps(blocks: *node->getCallableRegion(), sourceNode: node, cg, |
605 | symbolTable&: inlinerIface.symbolTable, calls, |
606 | /*traverseNestedCGNodes=*/false); |
607 | } |
608 | } |
609 | |
610 | // When inlining a callee produces new call sites, we want to keep track of |
611 | // the fact that they were inlined from the callee. This allows us to avoid |
612 | // infinite inlining. |
613 | using InlineHistoryT = std::optional<size_t>; |
614 | SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory; |
615 | std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); |
616 | |
617 | LLVM_DEBUG({ |
618 | llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n" ; |
619 | for (unsigned i = 0, e = calls.size(); i < e; ++i) |
620 | llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n" ; |
621 | llvm::dbgs() << "}\n" ; |
622 | }); |
623 | |
624 | // Try to inline each of the call operations. Don't cache the end iterator |
625 | // here as more calls may be added during inlining. |
626 | bool inlinedAnyCalls = false; |
627 | for (unsigned i = 0; i < calls.size(); ++i) { |
628 | if (deadNodes.contains(key: calls[i].sourceNode)) |
629 | continue; |
630 | ResolvedCall it = calls[i]; |
631 | |
632 | InlineHistoryT inlineHistoryID = callHistory[i]; |
633 | bool inHistory = |
634 | inlineHistoryIncludes(node: it.targetNode, inlineHistoryID, inlineHistory); |
635 | bool doInline = !inHistory && shouldInline(resolvedCall&: it); |
636 | CallOpInterface call = it.call; |
637 | LLVM_DEBUG({ |
638 | if (doInline) |
639 | llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n" ; |
640 | else |
641 | llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n" ; |
642 | }); |
643 | if (!doInline) |
644 | continue; |
645 | |
646 | unsigned prevSize = calls.size(); |
647 | Region *targetRegion = it.targetNode->getCallableRegion(); |
648 | |
649 | // If this is the last call to the target node and the node is discardable, |
650 | // then inline it in-place and delete the node if successful. |
651 | bool inlineInPlace = useList.hasOneUseAndDiscardable(node: it.targetNode); |
652 | |
653 | LogicalResult inlineResult = |
654 | inlineCall(inlinerIface, call, |
655 | cast<CallableOpInterface>(targetRegion->getParentOp()), |
656 | targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); |
657 | if (failed(result: inlineResult)) { |
658 | LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n" ); |
659 | continue; |
660 | } |
661 | inlinedAnyCalls = true; |
662 | |
663 | // Create a inline history entry for this inlined call, so that we remember |
664 | // that new callsites came about due to inlining Callee. |
665 | InlineHistoryT newInlineHistoryID{inlineHistory.size()}; |
666 | inlineHistory.push_back(Elt: std::make_pair(x&: it.targetNode, y&: inlineHistoryID)); |
667 | |
668 | auto historyToString = [](InlineHistoryT h) { |
669 | return h.has_value() ? std::to_string(val: *h) : "root" ; |
670 | }; |
671 | (void)historyToString; |
672 | LLVM_DEBUG(llvm::dbgs() |
673 | << "* new inlineHistory entry: " << newInlineHistoryID << ". [" |
674 | << getNodeName(call) << ", " << historyToString(inlineHistoryID) |
675 | << "]\n" ); |
676 | |
677 | for (unsigned k = prevSize; k != calls.size(); ++k) { |
678 | callHistory.push_back(x: newInlineHistoryID); |
679 | LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call |
680 | << "}\n with historyID = " << newInlineHistoryID |
681 | << ", added due to inlining of\n call {" << call |
682 | << "}\n with historyID = " |
683 | << historyToString(inlineHistoryID) << "\n" ); |
684 | } |
685 | |
686 | // If the inlining was successful, Merge the new uses into the source node. |
687 | useList.dropCallUses(userNode: it.sourceNode, callOp: call.getOperation(), cg); |
688 | useList.mergeUsesAfterInlining(lhs: it.targetNode, rhs: it.sourceNode); |
689 | |
690 | // then erase the call. |
691 | call.erase(); |
692 | |
693 | // If we inlined in place, mark the node for deletion. |
694 | if (inlineInPlace) { |
695 | useList.eraseNode(node: it.targetNode); |
696 | deadNodes.insert(X: it.targetNode); |
697 | } |
698 | } |
699 | |
700 | for (CallGraphNode *node : deadNodes) { |
701 | currentSCC.remove(node); |
702 | inlinerIface.markForDeletion(node); |
703 | } |
704 | calls.clear(); |
705 | return success(isSuccess: inlinedAnyCalls); |
706 | } |
707 | |
708 | /// Returns true if the given call should be inlined. |
709 | bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { |
710 | // Don't allow inlining terminator calls. We currently don't support this |
711 | // case. |
712 | if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>()) |
713 | return false; |
714 | |
715 | // Don't allow inlining if the target is a self-recursive function. |
716 | if (llvm::count_if(Range&: *resolvedCall.targetNode, |
717 | P: [&](CallGraphNode::Edge const &edge) -> bool { |
718 | return edge.getTarget() == resolvedCall.targetNode; |
719 | }) > 0) |
720 | return false; |
721 | |
722 | // Don't allow inlining if the target is an ancestor of the call. This |
723 | // prevents inlining recursively. |
724 | Region *callableRegion = resolvedCall.targetNode->getCallableRegion(); |
725 | if (callableRegion->isAncestor(other: resolvedCall.call->getParentRegion())) |
726 | return false; |
727 | |
728 | // Don't allow inlining if the callee has multiple blocks (unstructured |
729 | // control flow) but we cannot be sure that the caller region supports that. |
730 | bool calleeHasMultipleBlocks = |
731 | llvm::hasNItemsOrMore(C&: *callableRegion, /*N=*/2); |
732 | // If both parent ops have the same type, it is safe to inline. Otherwise, |
733 | // decide based on whether the op has the SingleBlock trait or not. |
734 | // Note: This check does currently not account for SizedRegion/MaxSizedRegion. |
735 | auto callerRegionSupportsMultipleBlocks = [&]() { |
736 | return callableRegion->getParentOp()->getName() == |
737 | resolvedCall.call->getParentOp()->getName() || |
738 | !resolvedCall.call->getParentOp() |
739 | ->mightHaveTrait<OpTrait::SingleBlock>(); |
740 | }; |
741 | if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) |
742 | return false; |
743 | |
744 | if (!inliner.isProfitableToInline(resolvedCall)) |
745 | return false; |
746 | |
747 | // Otherwise, inline. |
748 | return true; |
749 | } |
750 | |
751 | LogicalResult Inliner::doInlining() { |
752 | Impl impl(*this); |
753 | auto *context = op->getContext(); |
754 | // Run the inline transform in post-order over the SCCs in the callgraph. |
755 | SymbolTableCollection symbolTable; |
756 | // FIXME: some clean-up can be done for the arguments |
757 | // of the Impl's methods, if the inlinerIface and useList |
758 | // become the states of the Impl. |
759 | InlinerInterfaceImpl inlinerIface(context, cg, symbolTable); |
760 | CGUseList useList(op, cg, symbolTable); |
761 | LogicalResult result = runTransformOnCGSCCs(cg, sccTransformer: [&](CallGraphSCC &scc) { |
762 | return impl.inlineSCC(inlinerIface, useList, currentSCC&: scc, context); |
763 | }); |
764 | if (failed(result)) |
765 | return result; |
766 | |
767 | // After inlining, make sure to erase any callables proven to be dead. |
768 | inlinerIface.eraseDeadCallables(); |
769 | return success(); |
770 | } |
771 | } // namespace mlir |
772 | |