//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Inliner that uses a basic inlining
// algorithm that operates bottom up over the Strongly Connected Components
// (SCCs) of the CallGraph. This enables a more incremental propagation of
// inlining decisions from the leaves to the roots of the callgraph.
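//
// For example (an illustrative sketch, not a trace of an actual run): for a
// callgraph A -> B -> C with no cycles, the SCCs are visited bottom up as
// {C}, {B}, {A}, so C can be inlined into B before B is itself considered for
// inlining into A.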
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced with the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
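/// For example, a private `func.func` whose only references are the call
/// sites currently being inlined becomes use-empty once those calls are
/// rewritten, at which point it can be discarded. (Illustrative example; the
/// concrete op kinds depend on the dialects being processed.)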
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  /// A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
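/// Note that the SCC handed to the transformer is a copy of the iterator's
/// current SCC, and the iterator is advanced before the transformer runs, so
/// the transformer may remove nodes (see CallGraphSCC::remove) without
/// invalidating the traversal.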
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
          if (!isa<FlatSymbolRefAttr>(symRef))
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
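/// For example (illustrative only): if history entry 0 records that node B was
/// inlined, new calls produced by that inlining carry history ID 0; when such
/// a call again targets B, the walk from ID 0 back to the root finds B and the
/// call is skipped, bounding recursive inlining.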
static bool inlineHistoryIncludes(
    CallGraphNode *node, std::optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(*inlineHistoryID < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[*inlineHistoryID].first == node)
      return true;
    inlineHistoryID = inlineHistory[*inlineHistoryID].second;
  }
  return false;
}

namespace {
/// This class provides a specialization of the main inlining interface.
struct InlinerInterfaceImpl : public InlinerInterface {
  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
                       SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

namespace mlir {

class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
  /// an error occurred during the optimization of the callable, success
  /// otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to inline calls within the given scc. This function returns
  /// success if any calls were inlined, failure otherwise.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if the given call should be inlined.
  bool shouldInline(ResolvedCall &resolvedCall);

private:
  Inliner &inliner;
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};

LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
                                       CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
      break;
  } while (++iterationCount < inliner.config.getMaxInliningIterations());
  return success();
}

LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                         CallGraphSCC &currentSCC,
                                         MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
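  // For example, with a context configured for 8 threads, 8 copies of the op
  // pipelines are kept alive, and each worker below claims a free copy by
  // flipping the corresponding 'activePMs' flag.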
  size_t numThreads = ctx->getNumThreads();
  const auto &opPipelines = inliner.config.getOpPipelines();
  if (pipelines.size() < numThreads) {
    pipelines.reserve(numThreads);
    pipelines.resize(numThreads, opPipelines);
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    inliner.am.nest(node->getCallableRegion()->getParentOp());

  // An atomic failure variable for the async executors.
  std::vector<std::atomic<bool>> activePMs(pipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
Inliner::Impl::optimizeCallable(CallGraphNode *node,
                                llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the generic pipeline if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                CGUseList &useList, CallGraphSCC &currentSCC) {
  CallGraph &cg = inlinerIface.cg;
  auto &calls = inlinerIface.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg,
                     inlinerIface.symbolTable, calls,
                     /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = std::optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult =
        inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
                   cast<CallableOpInterface>(targetRegion->getParentOp()),
                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(*h) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
                              << "}\n with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inlinerIface.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Returns true if the given call should be inlined.
bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is a self-recursive function.
  // Don't allow inlining if the call graph is like A->B->A.
  if (llvm::count_if(*resolvedCall.targetNode,
                     [&](CallGraphNode::Edge const &edge) -> bool {
                       return edge.getTarget() == resolvedCall.targetNode ||
                              edge.getTarget() == resolvedCall.sourceNode;
                     }) > 0)
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
    return false;

  // Don't allow inlining if the callee has multiple blocks (unstructured
  // control flow) but we cannot be sure that the caller region supports that.
  if (!inliner.config.getCanHandleMultipleBlocks()) {
    bool calleeHasMultipleBlocks =
        llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
    // If both parent ops have the same type, it is safe to inline. Otherwise,
    // decide based on whether the op has the SingleBlock trait or not.
    // Note: This check does not currently account for
    // SizedRegion/MaxSizedRegion.
741 | return callableRegion->getParentOp()->getName() == |
742 | resolvedCall.call->getParentOp()->getName() || |
743 | !resolvedCall.call->getParentOp() |
744 | ->mightHaveTrait<OpTrait::SingleBlock>(); |
745 | }; |
746 | if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks()) |
747 | return false; |
748 | } |
749 | |
750 | if (!inliner.isProfitableToInline(resolvedCall)) |
751 | return false; |
752 | |
753 | // Otherwise, inline. |
754 | return true; |
755 | } |
756 | |
757 | LogicalResult Inliner::doInlining() { |
758 | Impl impl(*this); |
759 | auto *context = op->getContext(); |
760 | // Run the inline transform in post-order over the SCCs in the callgraph. |
761 | SymbolTableCollection symbolTable; |
762 | // FIXME: some clean-up can be done for the arguments |
763 | // of the Impl's methods, if the inlinerIface and useList |
764 | // become the states of the Impl. |
765 | InlinerInterfaceImpl inlinerIface(context, cg, symbolTable); |
766 | CGUseList useList(op, cg, symbolTable); |
767 | LogicalResult result = runTransformOnCGSCCs(cg, sccTransformer: [&](CallGraphSCC &scc) { |
768 | return impl.inlineSCC(inlinerIface, useList, currentSCC&: scc, context); |
769 | }); |
770 | if (failed(Result: result)) |
771 | return result; |
772 | |
773 | // After inlining, make sure to erase any callables proven to be dead. |
774 | inlinerIface.eraseDeadCallables(); |
775 | return success(); |
776 | } |
777 | } // namespace mlir |
778 |