1//===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10#define LLVM_ANALYSIS_MLINLINEADVISOR_H
11
12#include "llvm/Analysis/FunctionPropertiesAnalysis.h"
13#include "llvm/Analysis/InlineAdvisor.h"
14#include "llvm/Analysis/LazyCallGraph.h"
15#include "llvm/Analysis/MLModelRunner.h"
16#include "llvm/IR/PassManager.h"
17
18#include <deque>
19#include <map>
20#include <memory>
21#include <optional>
22
23namespace llvm {
24class DiagnosticInfoOptimizationBase;
25class Module;
26class MLInlineAdvice;
27
28class MLInlineAdvisor : public InlineAdvisor {
29public:
30 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
31 std::unique_ptr<MLModelRunner> ModelRunner,
32 std::function<bool(CallBase &)> GetDefaultAdvice);
33
34 virtual ~MLInlineAdvisor() = default;
35
36 void onPassEntry(LazyCallGraph::SCC *SCC) override;
37 void onPassExit(LazyCallGraph::SCC *SCC) override;
38
39 int64_t getIRSize(Function &F) const {
40 return getCachedFPI(F).TotalInstructionCount;
41 }
42 void onSuccessfulInlining(const MLInlineAdvice &Advice,
43 bool CalleeWasDeleted);
44
45 bool isForcedToStop() const { return ForceStop; }
46 int64_t getLocalCalls(Function &F);
47 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
48 FunctionPropertiesInfo &getCachedFPI(Function &) const;
49
50protected:
51 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
52
53 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
54 bool Advice) override;
55
56 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
57
58 virtual std::unique_ptr<MLInlineAdvice>
59 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
60
61 // Get the initial 'level' of the function, or 0 if the function has been
62 // introduced afterwards.
63 // TODO: should we keep this updated?
64 unsigned getInitialFunctionLevel(const Function &F) const;
65
66 std::unique_ptr<MLModelRunner> ModelRunner;
67 std::function<bool(CallBase &)> GetDefaultAdvice;
68
69private:
70 int64_t getModuleIRSize() const;
71 std::unique_ptr<InlineAdvice>
72 getSkipAdviceIfUnreachableCallsite(CallBase &CB);
73 void print(raw_ostream &OS) const override;
74
75 // Using std::map to benefit from its iterator / reference non-invalidating
76 // semantics, which make it easy to use `getCachedFPI` results from multiple
77 // calls without needing to copy to avoid invalidation effects.
78 mutable std::map<const Function *, FunctionPropertiesInfo> FPICache;
79
80 LazyCallGraph &CG;
81
82 int64_t NodeCount = 0;
83 int64_t EdgeCount = 0;
84 int64_t EdgesOfLastSeenNodes = 0;
85
86 std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
87 const int32_t InitialIRSize = 0;
88 int32_t CurrentIRSize = 0;
89 llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC;
90 DenseSet<const LazyCallGraph::Node *> AllNodes;
91 bool ForceStop = false;
92};
93
94/// InlineAdvice that tracks changes post inlining. For that reason, it only
95/// overrides the "successful inlining" extension points.
96class MLInlineAdvice : public InlineAdvice {
97public:
98 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
99 OptimizationRemarkEmitter &ORE, bool Recommendation);
100 virtual ~MLInlineAdvice() = default;
101
102 void recordInliningImpl() override;
103 void recordInliningWithCalleeDeletedImpl() override;
104 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
105 void recordUnattemptedInliningImpl() override;
106
107 Function *getCaller() const { return Caller; }
108 Function *getCallee() const { return Callee; }
109
110 const int64_t CallerIRSize;
111 const int64_t CalleeIRSize;
112 const int64_t CallerAndCalleeEdges;
113 void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const;
114
115private:
116 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
117 MLInlineAdvisor *getAdvisor() const {
118 return static_cast<MLInlineAdvisor *>(Advisor);
119 };
120 // Make a copy of the FPI of the caller right before inlining. If inlining
121 // fails, we can just update the cache with that value.
122 const FunctionPropertiesInfo PreInlineCallerFPI;
123 std::optional<FunctionPropertiesUpdater> FPU;
124};
125
126} // namespace llvm
127
128#endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
129

source code of llvm/include/llvm/Analysis/MLInlineAdvisor.h