//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Inliner that uses a basic inlining
// algorithm that operates bottom up over the Strongly Connected Components
// (SCCs) of the CallGraph. This enables a more incremental propagation of
// inlining decisions from the leaves to the roots of the callgraph.
//
//===----------------------------------------------------------------------===//
15
#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;
33
34//===----------------------------------------------------------------------===//
35// Symbol Use Tracking
36//===----------------------------------------------------------------------===//
37
38/// Walk all of the used symbol callgraph nodes referenced with the given op.
39static void walkReferencedSymbolNodes(
40 Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
41 DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
42 function_ref<void(CallGraphNode *, Operation *)> callback) {
43 auto symbolUses = SymbolTable::getSymbolUses(from: op);
44 assert(symbolUses && "expected uses to be valid");
45
46 Operation *symbolTableOp = op->getParentOp();
47 for (const SymbolTable::SymbolUse &use : *symbolUses) {
48 auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
49 CallGraphNode *&node = refIt.first->second;
50
51 // If this is the first instance of this reference, try to resolve a
52 // callgraph node for it.
53 if (refIt.second) {
54 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
55 use.getSymbolRef());
56 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
57 if (!callableOp)
58 continue;
59 node = cg.lookupNode(region: callableOp.getCallableRegion());
60 }
61 if (node)
62 callback(node, use.getUser());
63 }
64}
65
66//===----------------------------------------------------------------------===//
67// CGUseList
68//===----------------------------------------------------------------------===//
69
70namespace {
71/// This struct tracks the uses of callgraph nodes that can be dropped when
72/// use_empty. It directly tracks and manages a use-list for all of the
73/// call-graph nodes. This is necessary because many callgraph nodes are
74/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
75/// class.
76struct CGUseList {
77 /// This struct tracks the uses of callgraph nodes within a specific
78 /// operation.
79 struct CGUser {
80 /// Any nodes referenced in the top-level attribute list of this user. We
81 /// use a set here because the number of references does not matter.
82 DenseSet<CallGraphNode *> topLevelUses;
83
84 /// Uses of nodes referenced by nested operations.
85 DenseMap<CallGraphNode *, int> innerUses;
86 };
87
88 CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
89
90 /// Drop uses of nodes referred to by the given call operation that resides
91 /// within 'userNode'.
92 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
93
94 /// Remove the given node from the use list.
95 void eraseNode(CallGraphNode *node);
96
97 /// Returns true if the given callgraph node has no uses and can be pruned.
98 bool isDead(CallGraphNode *node) const;
99
100 /// Returns true if the given callgraph node has a single use and can be
101 /// discarded.
102 bool hasOneUseAndDiscardable(CallGraphNode *node) const;
103
104 /// Recompute the uses held by the given callgraph node.
105 void recomputeUses(CallGraphNode *node, CallGraph &cg);
106
107 /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
108 /// of 'lhs' into 'rhs'.
109 void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
110
111private:
112 /// Decrement the uses of discardable nodes referenced by the given user.
113 void decrementDiscardableUses(CGUser &uses);
114
115 /// A mapping between a discardable callgraph node (that is a symbol) and the
116 /// number of uses for this node.
117 DenseMap<CallGraphNode *, int> discardableSymNodeUses;
118
119 /// A mapping between a callgraph node and the symbol callgraph nodes that it
120 /// uses.
121 DenseMap<CallGraphNode *, CGUser> nodeUses;
122
123 /// A symbol table to use when resolving call lookups.
124 SymbolTableCollection &symbolTable;
125};
126} // namespace
127
128CGUseList::CGUseList(Operation *op, CallGraph &cg,
129 SymbolTableCollection &symbolTable)
130 : symbolTable(symbolTable) {
131 /// A set of callgraph nodes that are always known to be live during inlining.
132 DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
133
134 // Walk each of the symbol tables looking for discardable callgraph nodes.
135 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
136 for (Operation &op : symbolTableOp->getRegion(index: 0).getOps()) {
137 // If this is a callgraph operation, check to see if it is discardable.
138 if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
139 if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
140 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
141 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
142 symbol.canDiscardOnUseEmpty()) {
143 discardableSymNodeUses.try_emplace(node, 0);
144 }
145 continue;
146 }
147 }
148 // Otherwise, check for any referenced nodes. These will be always-live.
149 walkReferencedSymbolNodes(op: &op, cg, symbolTable, resolvedRefs&: alwaysLiveNodes,
150 callback: [](CallGraphNode *, Operation *) {});
151 }
152 };
153 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
154 callback: walkFn);
155
156 // Drop the use information for any discardable nodes that are always live.
157 for (auto &it : alwaysLiveNodes)
158 discardableSymNodeUses.erase(Val: it.second);
159
160 // Compute the uses for each of the callable nodes in the graph.
161 for (CallGraphNode *node : cg)
162 recomputeUses(node, cg);
163}
164
165void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
166 CallGraph &cg) {
167 auto &userRefs = nodeUses[userNode].innerUses;
168 auto walkFn = [&](CallGraphNode *node, Operation *user) {
169 auto parentIt = userRefs.find(Val: node);
170 if (parentIt == userRefs.end())
171 return;
172 --parentIt->second;
173 --discardableSymNodeUses[node];
174 };
175 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
176 walkReferencedSymbolNodes(op: callOp, cg, symbolTable, resolvedRefs, callback: walkFn);
177}
178
179void CGUseList::eraseNode(CallGraphNode *node) {
180 // Drop all child nodes.
181 for (auto &edge : *node)
182 if (edge.isChild())
183 eraseNode(node: edge.getTarget());
184
185 // Drop the uses held by this node and erase it.
186 auto useIt = nodeUses.find(Val: node);
187 assert(useIt != nodeUses.end() && "expected node to be valid");
188 decrementDiscardableUses(uses&: useIt->getSecond());
189 nodeUses.erase(I: useIt);
190 discardableSymNodeUses.erase(Val: node);
191}
192
193bool CGUseList::isDead(CallGraphNode *node) const {
194 // If the parent operation isn't a symbol, simply check normal SSA deadness.
195 Operation *nodeOp = node->getCallableRegion()->getParentOp();
196 if (!isa<SymbolOpInterface>(nodeOp))
197 return isMemoryEffectFree(op: nodeOp) && nodeOp->use_empty();
198
199 // Otherwise, check the number of symbol uses.
200 auto symbolIt = discardableSymNodeUses.find(Val: node);
201 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
202}
203
204bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
205 // If this isn't a symbol node, check for side-effects and SSA use count.
206 Operation *nodeOp = node->getCallableRegion()->getParentOp();
207 if (!isa<SymbolOpInterface>(nodeOp))
208 return isMemoryEffectFree(op: nodeOp) && nodeOp->hasOneUse();
209
210 // Otherwise, check the number of symbol uses.
211 auto symbolIt = discardableSymNodeUses.find(Val: node);
212 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
213}
214
215void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
216 Operation *parentOp = node->getCallableRegion()->getParentOp();
217 CGUser &uses = nodeUses[node];
218 decrementDiscardableUses(uses);
219
220 // Collect the new discardable uses within this node.
221 uses = CGUser();
222 DenseMap<Attribute, CallGraphNode *> resolvedRefs;
223 auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
224 auto discardSymIt = discardableSymNodeUses.find(Val: refNode);
225 if (discardSymIt == discardableSymNodeUses.end())
226 return;
227
228 if (user != parentOp)
229 ++uses.innerUses[refNode];
230 else if (!uses.topLevelUses.insert(V: refNode).second)
231 return;
232 ++discardSymIt->second;
233 };
234 walkReferencedSymbolNodes(op: parentOp, cg, symbolTable, resolvedRefs, callback: walkFn);
235}
236
237void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
238 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
239 for (auto &useIt : lhsUses.innerUses) {
240 rhsUses.innerUses[useIt.first] += useIt.second;
241 discardableSymNodeUses[useIt.first] += useIt.second;
242 }
243}
244
245void CGUseList::decrementDiscardableUses(CGUser &uses) {
246 for (CallGraphNode *node : uses.topLevelUses)
247 --discardableSymNodeUses[node];
248 for (auto &it : uses.innerUses)
249 discardableSymNodeUses[it.first] -= it.second;
250}
251
252//===----------------------------------------------------------------------===//
253// CallGraph traversal
254//===----------------------------------------------------------------------===//
255
256namespace {
257/// This class represents a specific callgraph SCC.
258class CallGraphSCC {
259public:
260 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
261 : parentIterator(parentIterator) {}
262 /// Return a range over the nodes within this SCC.
263 std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
264 std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
265
266 /// Reset the nodes of this SCC with those provided.
267 void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
268
269 /// Remove the given node from this SCC.
270 void remove(CallGraphNode *node) {
271 auto it = llvm::find(Range&: nodes, Val: node);
272 if (it != nodes.end()) {
273 nodes.erase(position: it);
274 parentIterator.ReplaceNode(Old: node, New: nullptr);
275 }
276 }
277
278private:
279 std::vector<CallGraphNode *> nodes;
280 llvm::scc_iterator<const CallGraph *> &parentIterator;
281};
282} // namespace
283
284/// Run a given transformation over the SCCs of the callgraph in a bottom up
285/// traversal.
286static LogicalResult runTransformOnCGSCCs(
287 const CallGraph &cg,
288 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
289 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(G: &cg);
290 CallGraphSCC currentSCC(cgi);
291 while (!cgi.isAtEnd()) {
292 // Copy the current SCC and increment so that the transformer can modify the
293 // SCC without invalidating our iterator.
294 currentSCC.reset(newNodes: *cgi);
295 ++cgi;
296 if (failed(Result: sccTransformer(currentSCC)))
297 return failure();
298 }
299 return success();
300}
301
302/// Collect all of the callable operations within the given range of blocks. If
303/// `traverseNestedCGNodes` is true, this will also collect call operations
304/// inside of nested callgraph nodes.
305static void collectCallOps(iterator_range<Region::iterator> blocks,
306 CallGraphNode *sourceNode, CallGraph &cg,
307 SymbolTableCollection &symbolTable,
308 SmallVectorImpl<ResolvedCall> &calls,
309 bool traverseNestedCGNodes) {
310 SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
311 auto addToWorklist = [&](CallGraphNode *node,
312 iterator_range<Region::iterator> blocks) {
313 for (Block &block : blocks)
314 worklist.emplace_back(Args: &block, Args&: node);
315 };
316
317 addToWorklist(sourceNode, blocks);
318 while (!worklist.empty()) {
319 Block *block;
320 std::tie(args&: block, args&: sourceNode) = worklist.pop_back_val();
321
322 for (Operation &op : *block) {
323 if (auto call = dyn_cast<CallOpInterface>(op)) {
324 // TODO: Support inlining nested call references.
325 CallInterfaceCallable callable = call.getCallableForCallee();
326 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
327 if (!isa<FlatSymbolRefAttr>(symRef))
328 continue;
329 }
330
331 CallGraphNode *targetNode = cg.resolveCallable(call: call, symbolTable);
332 if (!targetNode->isExternal())
333 calls.emplace_back(call, sourceNode, targetNode);
334 continue;
335 }
336
337 // If this is not a call, traverse the nested regions. If
338 // `traverseNestedCGNodes` is false, then don't traverse nested call graph
339 // regions.
340 for (auto &nestedRegion : op.getRegions()) {
341 CallGraphNode *nestedNode = cg.lookupNode(region: &nestedRegion);
342 if (traverseNestedCGNodes || !nestedNode)
343 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
344 }
345 }
346 }
347}
348
349//===----------------------------------------------------------------------===//
350// InlinerInterfaceImpl
351//===----------------------------------------------------------------------===//
352
353#ifndef NDEBUG
354static std::string getNodeName(CallOpInterface op) {
355 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
356 return debugString(op);
357 return "_unnamed_callee_";
358}
359#endif
360
361/// Return true if the specified `inlineHistoryID` indicates an inline history
362/// that already includes `node`.
363static bool inlineHistoryIncludes(
364 CallGraphNode *node, std::optional<size_t> inlineHistoryID,
365 MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
366 inlineHistory) {
367 while (inlineHistoryID.has_value()) {
368 assert(*inlineHistoryID < inlineHistory.size() &&
369 "Invalid inline history ID");
370 if (inlineHistory[*inlineHistoryID].first == node)
371 return true;
372 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
373 }
374 return false;
375}
376
377namespace {
378/// This class provides a specialization of the main inlining interface.
379struct InlinerInterfaceImpl : public InlinerInterface {
380 InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
381 SymbolTableCollection &symbolTable)
382 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
383
384 /// Process a set of blocks that have been inlined. This callback is invoked
385 /// *before* inlined terminator operations have been processed.
386 void
387 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
388 // Find the closest callgraph node from the first block.
389 CallGraphNode *node;
390 Region *region = inlinedBlocks.begin()->getParent();
391 while (!(node = cg.lookupNode(region))) {
392 region = region->getParentRegion();
393 assert(region && "expected valid parent node");
394 }
395
396 collectCallOps(blocks: inlinedBlocks, sourceNode: node, cg, symbolTable, calls,
397 /*traverseNestedCGNodes=*/true);
398 }
399
400 /// Mark the given callgraph node for deletion.
401 void markForDeletion(CallGraphNode *node) { deadNodes.insert(Ptr: node); }
402
403 /// This method properly disposes of callables that became dead during
404 /// inlining. This should not be called while iterating over the SCCs.
405 void eraseDeadCallables() {
406 for (CallGraphNode *node : deadNodes)
407 node->getCallableRegion()->getParentOp()->erase();
408 }
409
410 /// The set of callables known to be dead.
411 SmallPtrSet<CallGraphNode *, 8> deadNodes;
412
413 /// The current set of call instructions to consider for inlining.
414 SmallVector<ResolvedCall, 8> calls;
415
416 /// The callgraph being operated on.
417 CallGraph &cg;
418
419 /// A symbol table to use when resolving call lookups.
420 SymbolTableCollection &symbolTable;
421};
422} // namespace
423
424namespace mlir {
425
426class Inliner::Impl {
427public:
428 Impl(Inliner &inliner) : inliner(inliner) {}
429
430 /// Attempt to inline calls within the given scc, and run simplifications,
431 /// until a fixed point is reached. This allows for the inlining of newly
432 /// devirtualized calls. Returns failure if there was a fatal error during
433 /// inlining.
434 LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
435 CGUseList &useList, CallGraphSCC &currentSCC,
436 MLIRContext *context);
437
438private:
439 /// Optimize the nodes within the given SCC with one of the held optimization
440 /// pass pipelines. Returns failure if an error occurred during the
441 /// optimization of the SCC, success otherwise.
442 LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
443 CallGraphSCC &currentSCC, MLIRContext *context);
444
445 /// Optimize the nodes within the given SCC in parallel. Returns failure if an
446 /// error occurred during the optimization of the SCC, success otherwise.
447 LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
448 MLIRContext *context);
449
450 /// Optimize the given callable node with one of the pass managers provided
451 /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
452 /// an error occurred during the optimization of the callable, success
453 /// otherwise.
454 LogicalResult optimizeCallable(CallGraphNode *node,
455 llvm::StringMap<OpPassManager> &pipelines);
456
457 /// Attempt to inline calls within the given scc. This function returns
458 /// success if any calls were inlined, failure otherwise.
459 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
460 CGUseList &useList, CallGraphSCC &currentSCC);
461
462 /// Returns true if the given call should be inlined.
463 bool shouldInline(ResolvedCall &resolvedCall);
464
465private:
466 Inliner &inliner;
467 llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
468};
469
470LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
471 CGUseList &useList,
472 CallGraphSCC &currentSCC,
473 MLIRContext *context) {
474 // Continuously simplify and inline until we either reach a fixed point, or
475 // hit the maximum iteration count. Simplifying early helps to refine the cost
476 // model, and in future iterations may devirtualize new calls.
477 unsigned iterationCount = 0;
478 do {
479 if (failed(Result: optimizeSCC(cg&: inlinerIface.cg, useList, currentSCC, context)))
480 return failure();
481 if (failed(Result: inlineCallsInSCC(inlinerIface, useList, currentSCC)))
482 break;
483 } while (++iterationCount < inliner.config.getMaxInliningIterations());
484 return success();
485}
486
487LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
488 CallGraphSCC &currentSCC,
489 MLIRContext *context) {
490 // Collect the sets of nodes to simplify.
491 SmallVector<CallGraphNode *, 4> nodesToVisit;
492 for (auto *node : currentSCC) {
493 if (node->isExternal())
494 continue;
495
496 // Don't simplify nodes with children. Nodes with children require special
497 // handling as we may remove the node during simplification. In the future,
498 // we should be able to handle this case with proper node deletion tracking.
499 if (node->hasChildren())
500 continue;
501
502 // We also won't apply simplifications to nodes that can't have passes
503 // scheduled on them.
504 auto *region = node->getCallableRegion();
505 if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
506 continue;
507 nodesToVisit.push_back(Elt: node);
508 }
509 if (nodesToVisit.empty())
510 return success();
511
512 // Optimize each of the nodes within the SCC in parallel.
513 if (failed(Result: optimizeSCCAsync(nodesToVisit, context)))
514 return failure();
515
516 // Recompute the uses held by each of the nodes.
517 for (CallGraphNode *node : nodesToVisit)
518 useList.recomputeUses(node, cg);
519 return success();
520}
521
522LogicalResult
523Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
524 MLIRContext *ctx) {
525 // We must maintain a fixed pool of pass managers which is at least as large
526 // as the maximum parallelism of the failableParallelForEach below.
527 // Note: The number of pass managers here needs to remain constant
528 // to prevent issues with pass instrumentations that rely on having the same
529 // pass manager for the main thread.
530 size_t numThreads = ctx->getNumThreads();
531 const auto &opPipelines = inliner.config.getOpPipelines();
532 if (pipelines.size() < numThreads) {
533 pipelines.reserve(N: numThreads);
534 pipelines.resize(N: numThreads, NV: opPipelines);
535 }
536
537 // Ensure an analysis manager has been constructed for each of the nodes.
538 // This prevents thread races when running the nested pipelines.
539 for (CallGraphNode *node : nodesToVisit)
540 inliner.am.nest(op: node->getCallableRegion()->getParentOp());
541
542 // An atomic failure variable for the async executors.
543 std::vector<std::atomic<bool>> activePMs(pipelines.size());
544 std::fill(first: activePMs.begin(), last: activePMs.end(), value: false);
545 return failableParallelForEach(context: ctx, range&: nodesToVisit, func: [&](CallGraphNode *node) {
546 // Find a pass manager for this operation.
547 auto it = llvm::find_if(Range&: activePMs, P: [](std::atomic<bool> &isActive) {
548 bool expectedInactive = false;
549 return isActive.compare_exchange_strong(i1&: expectedInactive, i2: true);
550 });
551 assert(it != activePMs.end() &&
552 "could not find inactive pass manager for thread");
553 unsigned pmIndex = it - activePMs.begin();
554
555 // Optimize this callable node.
556 LogicalResult result = optimizeCallable(node, pipelines&: pipelines[pmIndex]);
557
558 // Reset the active bit for this pass manager.
559 activePMs[pmIndex].store(i: false);
560 return result;
561 });
562}
563
564LogicalResult
565Inliner::Impl::optimizeCallable(CallGraphNode *node,
566 llvm::StringMap<OpPassManager> &pipelines) {
567 Operation *callable = node->getCallableRegion()->getParentOp();
568 StringRef opName = callable->getName().getStringRef();
569 auto pipelineIt = pipelines.find(Key: opName);
570 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
571 if (pipelineIt == pipelines.end()) {
572 // If a pipeline didn't exist, use the generic pipeline if possible.
573 if (!defaultPipeline)
574 return success();
575
576 OpPassManager defaultPM(opName);
577 defaultPipeline(defaultPM);
578 pipelineIt = pipelines.try_emplace(Key: opName, Args: std::move(defaultPM)).first;
579 }
580 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
581}
582
583/// Attempt to inline calls within the given scc. This function returns
584/// success if any calls were inlined, failure otherwise.
585LogicalResult
586Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
587 CGUseList &useList, CallGraphSCC &currentSCC) {
588 CallGraph &cg = inlinerIface.cg;
589 auto &calls = inlinerIface.calls;
590
591 // A set of dead nodes to remove after inlining.
592 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
593
594 // Collect all of the direct calls within the nodes of the current SCC. We
595 // don't traverse nested callgraph nodes, because they are handled separately
596 // likely within a different SCC.
597 for (CallGraphNode *node : currentSCC) {
598 if (node->isExternal())
599 continue;
600
601 // Don't collect calls if the node is already dead.
602 if (useList.isDead(node)) {
603 deadNodes.insert(X: node);
604 } else {
605 collectCallOps(blocks: *node->getCallableRegion(), sourceNode: node, cg,
606 symbolTable&: inlinerIface.symbolTable, calls,
607 /*traverseNestedCGNodes=*/false);
608 }
609 }
610
611 // When inlining a callee produces new call sites, we want to keep track of
612 // the fact that they were inlined from the callee. This allows us to avoid
613 // infinite inlining.
614 using InlineHistoryT = std::optional<size_t>;
615 SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
616 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
617
618 LLVM_DEBUG({
619 llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
620 for (unsigned i = 0, e = calls.size(); i < e; ++i)
621 llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
622 llvm::dbgs() << "}\n";
623 });
624
625 // Try to inline each of the call operations. Don't cache the end iterator
626 // here as more calls may be added during inlining.
627 bool inlinedAnyCalls = false;
628 for (unsigned i = 0; i < calls.size(); ++i) {
629 if (deadNodes.contains(key: calls[i].sourceNode))
630 continue;
631 ResolvedCall it = calls[i];
632
633 InlineHistoryT inlineHistoryID = callHistory[i];
634 bool inHistory =
635 inlineHistoryIncludes(node: it.targetNode, inlineHistoryID, inlineHistory);
636 bool doInline = !inHistory && shouldInline(resolvedCall&: it);
637 CallOpInterface call = it.call;
638 LLVM_DEBUG({
639 if (doInline)
640 llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
641 else
642 llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
643 });
644 if (!doInline)
645 continue;
646
647 unsigned prevSize = calls.size();
648 Region *targetRegion = it.targetNode->getCallableRegion();
649
650 // If this is the last call to the target node and the node is discardable,
651 // then inline it in-place and delete the node if successful.
652 bool inlineInPlace = useList.hasOneUseAndDiscardable(node: it.targetNode);
653
654 LogicalResult inlineResult =
655 inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
656 cast<CallableOpInterface>(targetRegion->getParentOp()),
657 targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
658 if (failed(Result: inlineResult)) {
659 LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
660 continue;
661 }
662 inlinedAnyCalls = true;
663
664 // Create a inline history entry for this inlined call, so that we remember
665 // that new callsites came about due to inlining Callee.
666 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
667 inlineHistory.push_back(Elt: std::make_pair(x&: it.targetNode, y&: inlineHistoryID));
668
669 auto historyToString = [](InlineHistoryT h) {
670 return h.has_value() ? std::to_string(val: *h) : "root";
671 };
672 (void)historyToString;
673 LLVM_DEBUG(llvm::dbgs()
674 << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
675 << getNodeName(call) << ", " << historyToString(inlineHistoryID)
676 << "]\n");
677
678 for (unsigned k = prevSize; k != calls.size(); ++k) {
679 callHistory.push_back(x: newInlineHistoryID);
680 LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
681 << "}\n with historyID = " << newInlineHistoryID
682 << ", added due to inlining of\n call {" << call
683 << "}\n with historyID = "
684 << historyToString(inlineHistoryID) << "\n");
685 }
686
687 // If the inlining was successful, Merge the new uses into the source node.
688 useList.dropCallUses(userNode: it.sourceNode, callOp: call.getOperation(), cg);
689 useList.mergeUsesAfterInlining(lhs: it.targetNode, rhs: it.sourceNode);
690
691 // then erase the call.
692 call.erase();
693
694 // If we inlined in place, mark the node for deletion.
695 if (inlineInPlace) {
696 useList.eraseNode(node: it.targetNode);
697 deadNodes.insert(X: it.targetNode);
698 }
699 }
700
701 for (CallGraphNode *node : deadNodes) {
702 currentSCC.remove(node);
703 inlinerIface.markForDeletion(node);
704 }
705 calls.clear();
706 return success(IsSuccess: inlinedAnyCalls);
707}
708
709/// Returns true if the given call should be inlined.
710bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
711 // Don't allow inlining terminator calls. We currently don't support this
712 // case.
713 if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
714 return false;
715
716 // Don't allow inlining if the target is a self-recursive function.
717 // Don't allow inlining if the call graph is like A->B->A.
718 if (llvm::count_if(Range&: *resolvedCall.targetNode,
719 P: [&](CallGraphNode::Edge const &edge) -> bool {
720 return edge.getTarget() == resolvedCall.targetNode ||
721 edge.getTarget() == resolvedCall.sourceNode;
722 }) > 0)
723 return false;
724
725 // Don't allow inlining if the target is an ancestor of the call. This
726 // prevents inlining recursively.
727 Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
728 if (callableRegion->isAncestor(other: resolvedCall.call->getParentRegion()))
729 return false;
730
731 // Don't allow inlining if the callee has multiple blocks (unstructured
732 // control flow) but we cannot be sure that the caller region supports that.
733 if (!inliner.config.getCanHandleMultipleBlocks()) {
734 bool calleeHasMultipleBlocks =
735 llvm::hasNItemsOrMore(C&: *callableRegion, /*N=*/2);
736 // If both parent ops have the same type, it is safe to inline. Otherwise,
737 // decide based on whether the op has the SingleBlock trait or not.
738 // Note: This check does currently not account for
739 // SizedRegion/MaxSizedRegion.
740 auto callerRegionSupportsMultipleBlocks = [&]() {
741 return callableRegion->getParentOp()->getName() ==
742 resolvedCall.call->getParentOp()->getName() ||
743 !resolvedCall.call->getParentOp()
744 ->mightHaveTrait<OpTrait::SingleBlock>();
745 };
746 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
747 return false;
748 }
749
750 if (!inliner.isProfitableToInline(resolvedCall))
751 return false;
752
753 // Otherwise, inline.
754 return true;
755}
756
757LogicalResult Inliner::doInlining() {
758 Impl impl(*this);
759 auto *context = op->getContext();
760 // Run the inline transform in post-order over the SCCs in the callgraph.
761 SymbolTableCollection symbolTable;
762 // FIXME: some clean-up can be done for the arguments
763 // of the Impl's methods, if the inlinerIface and useList
764 // become the states of the Impl.
765 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
766 CGUseList useList(op, cg, symbolTable);
767 LogicalResult result = runTransformOnCGSCCs(cg, sccTransformer: [&](CallGraphSCC &scc) {
768 return impl.inlineSCC(inlinerIface, useList, currentSCC&: scc, context);
769 });
770 if (failed(Result: result))
771 return result;
772
773 // After inlining, make sure to erase any callables proven to be dead.
774 inlinerIface.eraseDeadCallables();
775 return success();
776}
777} // namespace mlir
778

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Transforms/Utils/Inliner.cpp