1//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements Inliner that uses a basic inlining
10// algorithm that operates bottom up over the Strongly Connect Components(SCCs)
11// of the CallGraph. This enables a more incremental propagation of inlining
12// decisions from the leafs to the roots of the callgraph.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Transforms/Inliner.h"
17#include "mlir/IR/Threading.h"
18#include "mlir/Interfaces/CallInterfaces.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Support/DebugStringHelper.h"
21#include "mlir/Transforms/InliningUtils.h"
22#include "llvm/ADT/SCCIterator.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "inlining"
27
28using namespace mlir;
29
30using ResolvedCall = Inliner::ResolvedCall;
31
32//===----------------------------------------------------------------------===//
33// Symbol Use Tracking
34//===----------------------------------------------------------------------===//
35
36/// Walk all of the used symbol callgraph nodes referenced with the given op.
37static void walkReferencedSymbolNodes(
38 Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
39 DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
40 function_ref<void(CallGraphNode *, Operation *)> callback) {
41 auto symbolUses = SymbolTable::getSymbolUses(from: op);
42 assert(symbolUses && "expected uses to be valid");
43
44 Operation *symbolTableOp = op->getParentOp();
45 for (const SymbolTable::SymbolUse &use : *symbolUses) {
46 auto refIt = resolvedRefs.try_emplace(Key: use.getSymbolRef());
47 CallGraphNode *&node = refIt.first->second;
48
49 // If this is the first instance of this reference, try to resolve a
50 // callgraph node for it.
51 if (refIt.second) {
52 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(from: symbolTableOp,
53 symbol: use.getSymbolRef());
54 auto callableOp = dyn_cast_or_null<CallableOpInterface>(Val: symbolOp);
55 if (!callableOp)
56 continue;
57 node = cg.lookupNode(region: callableOp.getCallableRegion());
58 }
59 if (node)
60 callback(node, use.getUser());
61 }
62}
63
64//===----------------------------------------------------------------------===//
65// CGUseList
66//===----------------------------------------------------------------------===//
67
namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations, mapped to the number of
    /// references held by each.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  /// Build the use list for every callgraph node reachable under `op`.
  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node (and, recursively, its child nodes) from the use
  /// list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace
125
126CGUseList::CGUseList(Operation *op, CallGraph &cg,
127 SymbolTableCollection &symbolTable)
128 : symbolTable(symbolTable) {
129 /// A set of callgraph nodes that are always known to be live during inlining.
130 DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
131
132 // Walk each of the symbol tables looking for discardable callgraph nodes.
133 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
134 for (Operation &op : symbolTableOp->getRegion(index: 0).getOps()) {
135 // If this is a callgraph operation, check to see if it is discardable.
136 if (auto callable = dyn_cast<CallableOpInterface>(Val: &op)) {
137 if (auto *node = cg.lookupNode(region: callable.getCallableRegion())) {
138 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(Val: &op);
139 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
140 symbol.canDiscardOnUseEmpty()) {
141 discardableSymNodeUses.try_emplace(Key: node, Args: 0);
142 }
143 continue;
144 }
145 }
146 // Otherwise, check for any referenced nodes. These will be always-live.
147 walkReferencedSymbolNodes(op: &op, cg, symbolTable, resolvedRefs&: alwaysLiveNodes,
148 callback: [](CallGraphNode *, Operation *) {});
149 }
150 };
151 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
152 callback: walkFn);
153
154 // Drop the use information for any discardable nodes that are always live.
155 for (auto &it : alwaysLiveNodes)
156 discardableSymNodeUses.erase(Val: it.second);
157
158 // Compute the uses for each of the callable nodes in the graph.
159 for (CallGraphNode *node : cg)
160 recomputeUses(node, cg);
161}
162
163void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
164 CallGraph &cg) {
165 auto &userRefs = nodeUses[userNode].innerUses;
166 auto walkFn = [&](CallGraphNode *node, Operation *user) {
167 auto parentIt = userRefs.find(Val: node);
168 if (parentIt == userRefs.end())
169 return;
170 --parentIt->second;
171 --discardableSymNodeUses[node];
172 };
173 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
174 walkReferencedSymbolNodes(op: callOp, cg, symbolTable, resolvedRefs, callback: walkFn);
175}
176
177void CGUseList::eraseNode(CallGraphNode *node) {
178 // Drop all child nodes.
179 for (auto &edge : *node)
180 if (edge.isChild())
181 eraseNode(node: edge.getTarget());
182
183 // Drop the uses held by this node and erase it.
184 auto useIt = nodeUses.find(Val: node);
185 assert(useIt != nodeUses.end() && "expected node to be valid");
186 decrementDiscardableUses(uses&: useIt->getSecond());
187 nodeUses.erase(I: useIt);
188 discardableSymNodeUses.erase(Val: node);
189}
190
191bool CGUseList::isDead(CallGraphNode *node) const {
192 // If the parent operation isn't a symbol, simply check normal SSA deadness.
193 Operation *nodeOp = node->getCallableRegion()->getParentOp();
194 if (!isa<SymbolOpInterface>(Val: nodeOp))
195 return isMemoryEffectFree(op: nodeOp) && nodeOp->use_empty();
196
197 // Otherwise, check the number of symbol uses.
198 auto symbolIt = discardableSymNodeUses.find(Val: node);
199 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
200}
201
202bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
203 // If this isn't a symbol node, check for side-effects and SSA use count.
204 Operation *nodeOp = node->getCallableRegion()->getParentOp();
205 if (!isa<SymbolOpInterface>(Val: nodeOp))
206 return isMemoryEffectFree(op: nodeOp) && nodeOp->hasOneUse();
207
208 // Otherwise, check the number of symbol uses.
209 auto symbolIt = discardableSymNodeUses.find(Val: node);
210 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
211}
212
213void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
214 Operation *parentOp = node->getCallableRegion()->getParentOp();
215 CGUser &uses = nodeUses[node];
216 decrementDiscardableUses(uses);
217
218 // Collect the new discardable uses within this node.
219 uses = CGUser();
220 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
221 auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
222 auto discardSymIt = discardableSymNodeUses.find(Val: refNode);
223 if (discardSymIt == discardableSymNodeUses.end())
224 return;
225
226 if (user != parentOp)
227 ++uses.innerUses[refNode];
228 else if (!uses.topLevelUses.insert(V: refNode).second)
229 return;
230 ++discardSymIt->second;
231 };
232 walkReferencedSymbolNodes(op: parentOp, cg, symbolTable, resolvedRefs, callback: walkFn);
233}
234
235void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
236 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
237 for (auto &useIt : lhsUses.innerUses) {
238 rhsUses.innerUses[useIt.first] += useIt.second;
239 discardableSymNodeUses[useIt.first] += useIt.second;
240 }
241}
242
243void CGUseList::decrementDiscardableUses(CGUser &uses) {
244 for (CallGraphNode *node : uses.topLevelUses)
245 --discardableSymNodeUses[node];
246 for (auto &it : uses.innerUses)
247 discardableSymNodeUses[it.first] -= it.second;
248}
249
250//===----------------------------------------------------------------------===//
251// CallGraph traversal
252//===----------------------------------------------------------------------===//
253
254namespace {
255/// This class represents a specific callgraph SCC.
256class CallGraphSCC {
257public:
258 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
259 : parentIterator(parentIterator) {}
260 /// Return a range over the nodes within this SCC.
261 std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
262 std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
263
264 /// Reset the nodes of this SCC with those provided.
265 void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
266
267 /// Remove the given node from this SCC.
268 void remove(CallGraphNode *node) {
269 auto it = llvm::find(Range&: nodes, Val: node);
270 if (it != nodes.end()) {
271 nodes.erase(position: it);
272 parentIterator.ReplaceNode(Old: node, New: nullptr);
273 }
274 }
275
276private:
277 std::vector<CallGraphNode *> nodes;
278 llvm::scc_iterator<const CallGraph *> &parentIterator;
279};
280} // namespace
281
282/// Run a given transformation over the SCCs of the callgraph in a bottom up
283/// traversal.
284static LogicalResult runTransformOnCGSCCs(
285 const CallGraph &cg,
286 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
287 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(G: &cg);
288 CallGraphSCC currentSCC(cgi);
289 while (!cgi.isAtEnd()) {
290 // Copy the current SCC and increment so that the transformer can modify the
291 // SCC without invalidating our iterator.
292 currentSCC.reset(newNodes: *cgi);
293 ++cgi;
294 if (failed(Result: sccTransformer(currentSCC)))
295 return failure();
296 }
297 return success();
298}
299
300/// Collect all of the callable operations within the given range of blocks. If
301/// `traverseNestedCGNodes` is true, this will also collect call operations
302/// inside of nested callgraph nodes.
303static void collectCallOps(iterator_range<Region::iterator> blocks,
304 CallGraphNode *sourceNode, CallGraph &cg,
305 SymbolTableCollection &symbolTable,
306 SmallVectorImpl<ResolvedCall> &calls,
307 bool traverseNestedCGNodes) {
308 SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
309 auto addToWorklist = [&](CallGraphNode *node,
310 iterator_range<Region::iterator> blocks) {
311 for (Block &block : blocks)
312 worklist.emplace_back(Args: &block, Args&: node);
313 };
314
315 addToWorklist(sourceNode, blocks);
316 while (!worklist.empty()) {
317 Block *block;
318 std::tie(args&: block, args&: sourceNode) = worklist.pop_back_val();
319
320 for (Operation &op : *block) {
321 if (auto call = dyn_cast<CallOpInterface>(Val&: op)) {
322 // TODO: Support inlining nested call references.
323 CallInterfaceCallable callable = call.getCallableForCallee();
324 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(Val&: callable)) {
325 if (!isa<FlatSymbolRefAttr>(Val: symRef))
326 continue;
327 }
328
329 CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
330 if (!targetNode->isExternal())
331 calls.emplace_back(Args&: call, Args&: sourceNode, Args&: targetNode);
332 continue;
333 }
334
335 // If this is not a call, traverse the nested regions. If
336 // `traverseNestedCGNodes` is false, then don't traverse nested call graph
337 // regions.
338 for (auto &nestedRegion : op.getRegions()) {
339 CallGraphNode *nestedNode = cg.lookupNode(region: &nestedRegion);
340 if (traverseNestedCGNodes || !nestedNode)
341 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
342 }
343 }
344 }
345}
346
347//===----------------------------------------------------------------------===//
348// InlinerInterfaceImpl
349//===----------------------------------------------------------------------===//
350
351#ifndef NDEBUG
352static std::string getNodeName(CallOpInterface op) {
353 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
354 return debugString(op);
355 return "_unnamed_callee_";
356}
357#endif
358
359/// Return true if the specified `inlineHistoryID` indicates an inline history
360/// that already includes `node`.
361static bool inlineHistoryIncludes(
362 CallGraphNode *node, std::optional<size_t> inlineHistoryID,
363 MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
364 inlineHistory) {
365 while (inlineHistoryID.has_value()) {
366 assert(*inlineHistoryID < inlineHistory.size() &&
367 "Invalid inline history ID");
368 if (inlineHistory[*inlineHistoryID].first == node)
369 return true;
370 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
371 }
372 return false;
373}
374
375namespace {
376/// This class provides a specialization of the main inlining interface.
377struct InlinerInterfaceImpl : public InlinerInterface {
378 InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
379 SymbolTableCollection &symbolTable)
380 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
381
382 /// Process a set of blocks that have been inlined. This callback is invoked
383 /// *before* inlined terminator operations have been processed.
384 void
385 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
386 // Find the closest callgraph node from the first block.
387 CallGraphNode *node;
388 Region *region = inlinedBlocks.begin()->getParent();
389 while (!(node = cg.lookupNode(region))) {
390 region = region->getParentRegion();
391 assert(region && "expected valid parent node");
392 }
393
394 collectCallOps(blocks: inlinedBlocks, sourceNode: node, cg, symbolTable, calls,
395 /*traverseNestedCGNodes=*/true);
396 }
397
398 /// Mark the given callgraph node for deletion.
399 void markForDeletion(CallGraphNode *node) { deadNodes.insert(Ptr: node); }
400
401 /// This method properly disposes of callables that became dead during
402 /// inlining. This should not be called while iterating over the SCCs.
403 void eraseDeadCallables() {
404 for (CallGraphNode *node : deadNodes)
405 node->getCallableRegion()->getParentOp()->erase();
406 }
407
408 /// The set of callables known to be dead.
409 SmallPtrSet<CallGraphNode *, 8> deadNodes;
410
411 /// The current set of call instructions to consider for inlining.
412 SmallVector<ResolvedCall, 8> calls;
413
414 /// The callgraph being operated on.
415 CallGraph &cg;
416
417 /// A symbol table to use when resolving call lookups.
418 SymbolTableCollection &symbolTable;
419};
420} // namespace
421
422namespace mlir {
423
/// Private implementation of the inliner: drives per-SCC simplification and
/// inlining on behalf of Inliner::doInlining.
class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
  /// error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
  /// an error occurred during the optimization of the callable, success
  /// otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to inline calls within the given scc. This function returns
  /// success if any calls were inlined, failure otherwise.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if the given call should be inlined.
  bool shouldInline(ResolvedCall &resolvedCall);

private:
  /// The inliner driver this implementation operates on behalf of.
  Inliner &inliner;
  /// A pool of pass-manager maps, one per worker thread; grown lazily by
  /// optimizeSCCAsync up to the context's thread count.
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};
467
468LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
469 CGUseList &useList,
470 CallGraphSCC &currentSCC,
471 MLIRContext *context) {
472 // Continuously simplify and inline until we either reach a fixed point, or
473 // hit the maximum iteration count. Simplifying early helps to refine the cost
474 // model, and in future iterations may devirtualize new calls.
475 unsigned iterationCount = 0;
476 do {
477 if (failed(Result: optimizeSCC(cg&: inlinerIface.cg, useList, currentSCC, context)))
478 return failure();
479 if (failed(Result: inlineCallsInSCC(inlinerIface, useList, currentSCC)))
480 break;
481 } while (++iterationCount < inliner.config.getMaxInliningIterations());
482 return success();
483}
484
485LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
486 CallGraphSCC &currentSCC,
487 MLIRContext *context) {
488 // Collect the sets of nodes to simplify.
489 SmallVector<CallGraphNode *, 4> nodesToVisit;
490 for (auto *node : currentSCC) {
491 if (node->isExternal())
492 continue;
493
494 // Don't simplify nodes with children. Nodes with children require special
495 // handling as we may remove the node during simplification. In the future,
496 // we should be able to handle this case with proper node deletion tracking.
497 if (node->hasChildren())
498 continue;
499
500 // We also won't apply simplifications to nodes that can't have passes
501 // scheduled on them.
502 auto *region = node->getCallableRegion();
503 if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
504 continue;
505 nodesToVisit.push_back(Elt: node);
506 }
507 if (nodesToVisit.empty())
508 return success();
509
510 // Optimize each of the nodes within the SCC in parallel.
511 if (failed(Result: optimizeSCCAsync(nodesToVisit, context)))
512 return failure();
513
514 // Recompute the uses held by each of the nodes.
515 for (CallGraphNode *node : nodesToVisit)
516 useList.recomputeUses(node, cg);
517 return success();
518}
519
520LogicalResult
521Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
522 MLIRContext *ctx) {
523 // We must maintain a fixed pool of pass managers which is at least as large
524 // as the maximum parallelism of the failableParallelForEach below.
525 // Note: The number of pass managers here needs to remain constant
526 // to prevent issues with pass instrumentations that rely on having the same
527 // pass manager for the main thread.
528 size_t numThreads = ctx->getNumThreads();
529 const auto &opPipelines = inliner.config.getOpPipelines();
530 if (pipelines.size() < numThreads) {
531 pipelines.reserve(N: numThreads);
532 pipelines.resize(N: numThreads, NV: opPipelines);
533 }
534
535 // Ensure an analysis manager has been constructed for each of the nodes.
536 // This prevents thread races when running the nested pipelines.
537 for (CallGraphNode *node : nodesToVisit)
538 inliner.am.nest(op: node->getCallableRegion()->getParentOp());
539
540 // An atomic failure variable for the async executors.
541 std::vector<std::atomic<bool>> activePMs(pipelines.size());
542 llvm::fill(Range&: activePMs, Value: false);
543 return failableParallelForEach(context: ctx, range&: nodesToVisit, func: [&](CallGraphNode *node) {
544 // Find a pass manager for this operation.
545 auto it = llvm::find_if(Range&: activePMs, P: [](std::atomic<bool> &isActive) {
546 bool expectedInactive = false;
547 return isActive.compare_exchange_strong(i1&: expectedInactive, i2: true);
548 });
549 assert(it != activePMs.end() &&
550 "could not find inactive pass manager for thread");
551 unsigned pmIndex = it - activePMs.begin();
552
553 // Optimize this callable node.
554 LogicalResult result = optimizeCallable(node, pipelines&: pipelines[pmIndex]);
555
556 // Reset the active bit for this pass manager.
557 activePMs[pmIndex].store(i: false);
558 return result;
559 });
560}
561
562LogicalResult
563Inliner::Impl::optimizeCallable(CallGraphNode *node,
564 llvm::StringMap<OpPassManager> &pipelines) {
565 Operation *callable = node->getCallableRegion()->getParentOp();
566 StringRef opName = callable->getName().getStringRef();
567 auto pipelineIt = pipelines.find(Key: opName);
568 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
569 if (pipelineIt == pipelines.end()) {
570 // If a pipeline didn't exist, use the generic pipeline if possible.
571 if (!defaultPipeline)
572 return success();
573
574 OpPassManager defaultPM(opName);
575 defaultPipeline(defaultPM);
576 pipelineIt = pipelines.try_emplace(Key: opName, Args: std::move(defaultPM)).first;
577 }
578 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
579}
580
581/// Attempt to inline calls within the given scc. This function returns
582/// success if any calls were inlined, failure otherwise.
583LogicalResult
584Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
585 CGUseList &useList, CallGraphSCC &currentSCC) {
586 CallGraph &cg = inlinerIface.cg;
587 auto &calls = inlinerIface.calls;
588
589 // A set of dead nodes to remove after inlining.
590 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
591
592 // Collect all of the direct calls within the nodes of the current SCC. We
593 // don't traverse nested callgraph nodes, because they are handled separately
594 // likely within a different SCC.
595 for (CallGraphNode *node : currentSCC) {
596 if (node->isExternal())
597 continue;
598
599 // Don't collect calls if the node is already dead.
600 if (useList.isDead(node)) {
601 deadNodes.insert(X: node);
602 } else {
603 collectCallOps(blocks: *node->getCallableRegion(), sourceNode: node, cg,
604 symbolTable&: inlinerIface.symbolTable, calls,
605 /*traverseNestedCGNodes=*/false);
606 }
607 }
608
609 // When inlining a callee produces new call sites, we want to keep track of
610 // the fact that they were inlined from the callee. This allows us to avoid
611 // infinite inlining.
612 using InlineHistoryT = std::optional<size_t>;
613 SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
614 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
615
616 LLVM_DEBUG({
617 llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
618 for (unsigned i = 0, e = calls.size(); i < e; ++i)
619 llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
620 llvm::dbgs() << "}\n";
621 });
622
623 // Try to inline each of the call operations. Don't cache the end iterator
624 // here as more calls may be added during inlining.
625 bool inlinedAnyCalls = false;
626 for (unsigned i = 0; i < calls.size(); ++i) {
627 if (deadNodes.contains(key: calls[i].sourceNode))
628 continue;
629 ResolvedCall it = calls[i];
630
631 InlineHistoryT inlineHistoryID = callHistory[i];
632 bool inHistory =
633 inlineHistoryIncludes(node: it.targetNode, inlineHistoryID, inlineHistory);
634 bool doInline = !inHistory && shouldInline(resolvedCall&: it);
635 CallOpInterface call = it.call;
636 LLVM_DEBUG({
637 if (doInline)
638 llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
639 else
640 llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
641 });
642 if (!doInline)
643 continue;
644
645 unsigned prevSize = calls.size();
646 Region *targetRegion = it.targetNode->getCallableRegion();
647
648 // If this is the last call to the target node and the node is discardable,
649 // then inline it in-place and delete the node if successful.
650 bool inlineInPlace = useList.hasOneUseAndDiscardable(node: it.targetNode);
651
652 LogicalResult inlineResult =
653 inlineCall(interface&: inlinerIface, cloneCallback: inliner.config.getCloneCallback(), call,
654 callable: cast<CallableOpInterface>(Val: targetRegion->getParentOp()),
655 src: targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
656 if (failed(Result: inlineResult)) {
657 LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
658 continue;
659 }
660 inlinedAnyCalls = true;
661
662 // Create a inline history entry for this inlined call, so that we remember
663 // that new callsites came about due to inlining Callee.
664 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
665 inlineHistory.push_back(Elt: std::make_pair(x&: it.targetNode, y&: inlineHistoryID));
666
667 auto historyToString = [](InlineHistoryT h) {
668 return h.has_value() ? std::to_string(val: *h) : "root";
669 };
670 (void)historyToString;
671 LLVM_DEBUG(llvm::dbgs()
672 << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
673 << getNodeName(call) << ", " << historyToString(inlineHistoryID)
674 << "]\n");
675
676 for (unsigned k = prevSize; k != calls.size(); ++k) {
677 callHistory.push_back(x: newInlineHistoryID);
678 LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
679 << "}\n with historyID = " << newInlineHistoryID
680 << ", added due to inlining of\n call {" << call
681 << "}\n with historyID = "
682 << historyToString(inlineHistoryID) << "\n");
683 }
684
685 // If the inlining was successful, Merge the new uses into the source node.
686 useList.dropCallUses(userNode: it.sourceNode, callOp: call.getOperation(), cg);
687 useList.mergeUsesAfterInlining(lhs: it.targetNode, rhs: it.sourceNode);
688
689 // then erase the call.
690 call.erase();
691
692 // If we inlined in place, mark the node for deletion.
693 if (inlineInPlace) {
694 useList.eraseNode(node: it.targetNode);
695 deadNodes.insert(X: it.targetNode);
696 }
697 }
698
699 for (CallGraphNode *node : deadNodes) {
700 currentSCC.remove(node);
701 inlinerIface.markForDeletion(node);
702 }
703 calls.clear();
704 return success(IsSuccess: inlinedAnyCalls);
705}
706
707/// Returns true if the given call should be inlined.
708bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
709 // Don't allow inlining terminator calls. We currently don't support this
710 // case.
711 if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
712 return false;
713
714 // Don't allow inlining if the target is a self-recursive function.
715 // Don't allow inlining if the call graph is like A->B->A.
716 if (llvm::count_if(Range&: *resolvedCall.targetNode,
717 P: [&](CallGraphNode::Edge const &edge) -> bool {
718 return edge.getTarget() == resolvedCall.targetNode ||
719 edge.getTarget() == resolvedCall.sourceNode;
720 }) > 0)
721 return false;
722
723 // Don't allow inlining if the target is an ancestor of the call. This
724 // prevents inlining recursively.
725 Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
726 if (callableRegion->isAncestor(other: resolvedCall.call->getParentRegion()))
727 return false;
728
729 // Don't allow inlining if the callee has multiple blocks (unstructured
730 // control flow) but we cannot be sure that the caller region supports that.
731 if (!inliner.config.getCanHandleMultipleBlocks()) {
732 bool calleeHasMultipleBlocks =
733 llvm::hasNItemsOrMore(C&: *callableRegion, /*N=*/2);
734 // If both parent ops have the same type, it is safe to inline. Otherwise,
735 // decide based on whether the op has the SingleBlock trait or not.
736 // Note: This check does currently not account for
737 // SizedRegion/MaxSizedRegion.
738 auto callerRegionSupportsMultipleBlocks = [&]() {
739 return callableRegion->getParentOp()->getName() ==
740 resolvedCall.call->getParentOp()->getName() ||
741 !resolvedCall.call->getParentOp()
742 ->mightHaveTrait<OpTrait::SingleBlock>();
743 };
744 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
745 return false;
746 }
747
748 if (!inliner.isProfitableToInline(resolvedCall))
749 return false;
750
751 // Otherwise, inline.
752 return true;
753}
754
755LogicalResult Inliner::doInlining() {
756 Impl impl(*this);
757 auto *context = op->getContext();
758 // Run the inline transform in post-order over the SCCs in the callgraph.
759 SymbolTableCollection symbolTable;
760 // FIXME: some clean-up can be done for the arguments
761 // of the Impl's methods, if the inlinerIface and useList
762 // become the states of the Impl.
763 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
764 CGUseList useList(op, cg, symbolTable);
765 LogicalResult result = runTransformOnCGSCCs(cg, sccTransformer: [&](CallGraphSCC &scc) {
766 return impl.inlineSCC(inlinerIface, useList, currentSCC&: scc, context);
767 });
768 if (failed(Result: result))
769 return result;
770
771 // After inlining, make sure to erase any callables proven to be dead.
772 inlinerIface.eraseDeadCallables();
773 return success();
774}
775} // namespace mlir
776

// Source: mlir/lib/Transforms/Utils/Inliner.cpp