1 | //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H |
10 | #define LLVM_ANALYSIS_MLINLINEADVISOR_H |
11 | |
12 | #include "llvm/Analysis/FunctionPropertiesAnalysis.h" |
13 | #include "llvm/Analysis/InlineAdvisor.h" |
14 | #include "llvm/Analysis/LazyCallGraph.h" |
15 | #include "llvm/Analysis/MLModelRunner.h" |
16 | #include "llvm/IR/PassManager.h" |
17 | |
18 | #include <deque> |
19 | #include <map> |
20 | #include <memory> |
21 | #include <optional> |
22 | |
23 | namespace llvm { |
24 | class DiagnosticInfoOptimizationBase; |
25 | class Module; |
26 | class MLInlineAdvice; |
27 | |
28 | class MLInlineAdvisor : public InlineAdvisor { |
29 | public: |
30 | MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, |
31 | std::unique_ptr<MLModelRunner> ModelRunner, |
32 | std::function<bool(CallBase &)> GetDefaultAdvice); |
33 | |
34 | virtual ~MLInlineAdvisor() = default; |
35 | |
36 | void onPassEntry(LazyCallGraph::SCC *SCC) override; |
37 | void onPassExit(LazyCallGraph::SCC *SCC) override; |
38 | |
39 | int64_t getIRSize(Function &F) const { |
40 | return getCachedFPI(F).TotalInstructionCount; |
41 | } |
42 | void onSuccessfulInlining(const MLInlineAdvice &Advice, |
43 | bool CalleeWasDeleted); |
44 | |
45 | bool isForcedToStop() const { return ForceStop; } |
46 | int64_t getLocalCalls(Function &F); |
47 | const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } |
48 | FunctionPropertiesInfo &getCachedFPI(Function &) const; |
49 | |
50 | protected: |
51 | std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; |
52 | |
53 | std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, |
54 | bool Advice) override; |
55 | |
56 | virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); |
57 | |
58 | virtual std::unique_ptr<MLInlineAdvice> |
59 | (CallBase &CB, OptimizationRemarkEmitter &ORE); |
60 | |
61 | // Get the initial 'level' of the function, or 0 if the function has been |
62 | // introduced afterwards. |
63 | // TODO: should we keep this updated? |
64 | unsigned getInitialFunctionLevel(const Function &F) const; |
65 | |
66 | std::unique_ptr<MLModelRunner> ModelRunner; |
67 | std::function<bool(CallBase &)> GetDefaultAdvice; |
68 | |
69 | private: |
70 | int64_t getModuleIRSize() const; |
71 | std::unique_ptr<InlineAdvice> |
72 | getSkipAdviceIfUnreachableCallsite(CallBase &CB); |
73 | void print(raw_ostream &OS) const override; |
74 | |
75 | // Using std::map to benefit from its iterator / reference non-invalidating |
76 | // semantics, which make it easy to use `getCachedFPI` results from multiple |
77 | // calls without needing to copy to avoid invalidation effects. |
78 | mutable std::map<const Function *, FunctionPropertiesInfo> FPICache; |
79 | |
80 | LazyCallGraph &CG; |
81 | |
82 | int64_t NodeCount = 0; |
83 | int64_t EdgeCount = 0; |
84 | int64_t EdgesOfLastSeenNodes = 0; |
85 | |
86 | std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; |
87 | const int32_t InitialIRSize = 0; |
88 | int32_t CurrentIRSize = 0; |
89 | llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC; |
90 | DenseSet<const LazyCallGraph::Node *> AllNodes; |
91 | bool ForceStop = false; |
92 | }; |
93 | |
94 | /// InlineAdvice that tracks changes post inlining. For that reason, it only |
95 | /// overrides the "successful inlining" extension points. |
96 | class MLInlineAdvice : public InlineAdvice { |
97 | public: |
98 | (MLInlineAdvisor *Advisor, CallBase &CB, |
99 | OptimizationRemarkEmitter &ORE, bool Recommendation); |
100 | virtual ~MLInlineAdvice() = default; |
101 | |
102 | void recordInliningImpl() override; |
103 | void recordInliningWithCalleeDeletedImpl() override; |
104 | void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; |
105 | void recordUnattemptedInliningImpl() override; |
106 | |
107 | Function *getCaller() const { return Caller; } |
108 | Function *getCallee() const { return Callee; } |
109 | |
110 | const int64_t CallerIRSize; |
111 | const int64_t CalleeIRSize; |
112 | const int64_t CallerAndCalleeEdges; |
113 | void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const; |
114 | |
115 | private: |
116 | void (DiagnosticInfoOptimizationBase &OR); |
117 | MLInlineAdvisor *getAdvisor() const { |
118 | return static_cast<MLInlineAdvisor *>(Advisor); |
119 | }; |
120 | // Make a copy of the FPI of the caller right before inlining. If inlining |
121 | // fails, we can just update the cache with that value. |
122 | const FunctionPropertiesInfo PreInlineCallerFPI; |
123 | std::optional<FunctionPropertiesUpdater> FPU; |
124 | }; |
125 | |
126 | } // namespace llvm |
127 | |
128 | #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H |
129 | |