//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Inliner, which uses a basic inlining algorithm that
// operates bottom-up over the Strongly Connected Components (SCCs) of the
// CallGraph. This enables a more incremental propagation of inlining decisions
// from the leaves to the roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}
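
// Illustrative usage sketch (not part of the implementation): for an operation
// `op` carrying a symbol reference such as `@foo` that resolves to a
// `func.func @foo`, the walk resolves the reference through the symbol table
// and fires the callback once per use:
//
//   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
//   walkReferencedSymbolNodes(op, cg, symbolTable, resolvedRefs,
//                             [](CallGraphNode *node, Operation *user) {
//                               // `node` is the callgraph node of the
//                               // referenced callable; `user` is the op
//                               // holding the reference.
//                             });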

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use-empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy of
  /// 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace
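
// For intuition, an illustrative example (not taken from any test): in the
// module below, @helper is private and can be discarded on use-empty, so its
// node is tracked in `discardableSymNodeUses` with a use count of one.
// Inlining the lone call site drops the count to zero, at which point isDead()
// returns true and the callable can be erased:
//
//   module {
//     func.func private @helper() { return }
//     func.func @entry() {
//       call @helper() : () -> ()
//       return
//     }
//   }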

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}
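
// To make the visitation order concrete (an illustrative sketch): for a
// callgraph where @a calls @b, and @b and @c call each other, llvm::scc_begin
// yields the SCCs bottom-up:
//
//   SCC 1: {@b, @c}   // the mutually-recursive cycle, visited first
//   SCC 2: {@a}       // the caller, visited after its callees
//
// so inlining decisions made inside {@b, @c} are already in effect when @a's
// calls are considered.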

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
          if (!isa<FlatSymbolRefAttr>(symRef))
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call
      // graph regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}
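
// A sketch of the two traversal modes (illustrative only): in
//
//   func.func @caller() {
//     call @f() : () -> ()        // collected in both modes
//     scf.execute_region {        // not a callgraph node: always entered
//       call @g() : () -> ()      // collected in both modes
//       scf.yield
//     }
//     return
//   }
//
// nested regions that form their own callgraph nodes are only entered when
// `traverseNestedCGNodes` is true; otherwise their calls are left for the SCC
// that owns them.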

//===----------------------------------------------------------------------===//
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
static bool inlineHistoryIncludes(
    CallGraphNode *node, std::optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(*inlineHistoryID < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[*inlineHistoryID].first == node)
      return true;
    inlineHistoryID = inlineHistory[*inlineHistoryID].second;
  }
  return false;
}
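
// The history is a parent-pointer chain. Illustrative sketch: if @a's call to
// @b is inlined (entry 0, parent `root`), and the copy of @b's call to @c
// exposed by that inlining is inlined next (entry 1, parent 0), then:
//
//   inlineHistory[0] = {@b, root}
//   inlineHistory[1] = {@c, 0}
//
// A call carrying history ID 1 is rejected for targets @c and @b, since both
// appear on the chain 1 -> 0 -> root; this is what bounds recursive inlining.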

namespace {
/// This class provides a specialization of the main inlining interface.
struct InlinerInterfaceImpl : public InlinerInterface {
  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
                       SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

namespace mlir {

class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
  /// an error occurred during the optimization of the callable, success
  /// otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to inline calls within the given scc. This function returns
  /// success if any calls were inlined, failure otherwise.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if the given call should be inlined.
  bool shouldInline(ResolvedCall &resolvedCall);

private:
  Inliner &inliner;
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};

LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
                                       CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
      break;
  } while (++iterationCount < inliner.config.getMaxInliningIterations());
  return success();
}
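
// A sketch of why iterating matters (illustrative, with hypothetical ops):
// suppose the optimizeSCC step folds an indirect call into a direct one,
//
//   %f = "test.get_fn"() {fn = @callee} : () -> !some.fn
//   "test.call_indirect"(%f) : (!some.fn) -> ()
//     ==>  (after simplification)
//   call @callee() : () -> ()
//
// The direct call only becomes visible to inlineCallsInSCC on the next
// iteration, so the loop repeats until nothing is inlined or
// config.getMaxInliningIterations() is reached.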

LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                         CallGraphSCC &currentSCC,
                                         MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = ctx->getNumThreads();
  const auto &opPipelines = inliner.config.getOpPipelines();
  if (pipelines.size() < numThreads) {
    pipelines.reserve(numThreads);
    pipelines.resize(numThreads, opPipelines);
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    inliner.am.nest(node->getCallableRegion()->getParentOp());

  // A pool of 'active' flags, one per pass manager, used by the async
  // executors to claim a unique pass manager for each task.
  std::vector<std::atomic<bool>> activePMs(pipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
Inliner::Impl::optimizeCallable(CallGraphNode *node,
                                llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the generic pipeline if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
}
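
// Usage sketch (hypothetical client code; check the InlinerConfig declaration
// in mlir/Transforms/Inliner.h for the exact constructor/setters on your
// revision): a config whose default pipeline canonicalizes each callable
// before inlining could look like
//
//   InlinerConfig config([](OpPassManager &pm) {
//     pm.addPass(createCanonicalizerPass());
//   }, /*maxInliningIterations=*/4);
//
// Callables whose op name has an entry in getOpPipelines() use that pipeline
// instead; everything else falls back to this default.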

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                CGUseList &useList, CallGraphSCC &currentSCC) {
  CallGraph &cg = inlinerIface.cg;
  auto &calls = inlinerIface.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg,
                     inlinerIface.symbolTable, calls,
                     /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = std::optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult =
        inlineCall(inlinerIface, call,
                   cast<CallableOpInterface>(targetRegion->getParentOp()),
                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(*h) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n   with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n  call {" << call
                              << "}\n   with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inlinerIface.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Returns true if the given call should be inlined.
bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is a self-recursive function.
  if (llvm::count_if(*resolvedCall.targetNode,
                     [&](CallGraphNode::Edge const &edge) -> bool {
                       return edge.getTarget() == resolvedCall.targetNode;
                     }) > 0)
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
    return false;

  // Don't allow inlining if the callee has multiple blocks (unstructured
  // control flow) and we cannot be sure that the caller region supports that.
  bool calleeHasMultipleBlocks =
      llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
  // If both parent ops have the same type, it is safe to inline. Otherwise,
  // decide based on whether the op has the SingleBlock trait or not.
  // Note: This check does not currently account for
  // SizedRegion/MaxSizedRegion.
  auto callerRegionSupportsMultipleBlocks = [&]() {
    return callableRegion->getParentOp()->getName() ==
               resolvedCall.call->getParentOp()->getName() ||
           !resolvedCall.call->getParentOp()
                ->mightHaveTrait<OpTrait::SingleBlock>();
  };
  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
    return false;
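
  // For example (illustrative only): a callee such as
  //
  //   func.func @multi_block() {
  //     cf.br ^bb1
  //   ^bb1:
  //     return
  //   }
  //
  // is rejected when the call sits inside a single-block region (e.g. the
  // body of an scf.if), but accepted when the call is directly inside another
  // func.func, since func.func bodies allow multiple blocks.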

  if (!inliner.isProfitableToInline(resolvedCall))
    return false;

  // Otherwise, inline.
  return true;
}

LogicalResult Inliner::doInlining() {
  Impl impl(*this);
  auto *context = op->getContext();
  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  // FIXME: Some clean-up can be done for the arguments of the Impl's methods,
  // if the inlinerIface and useList become state of the Impl.
  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
  CGUseList useList(op, cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return impl.inlineSCC(inlinerIface, useList, scc, context);
  });
  if (failed(result))
    return result;

  // After inlining, make sure to erase any callables proven to be dead.
  inlinerIface.eraseDeadCallables();
  return success();
}
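
// A sketch of how a pass can drive this entry point (hypothetical client
// code; the in-tree InlinerPass is the authoritative example, and the exact
// Inliner constructor arguments depend on the revision of Inliner.h):
//
//   CallGraph &cg = getAnalysis<CallGraph>();
//   Inliner inliner(getOperation(), cg, /*pass=*/*this, getAnalysisManager(),
//                   runPipelineHelper, config, isProfitableToInlineCallback);
//   if (failed(inliner.doInlining()))
//     signalPassFailure();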
} // namespace mlir
