1//===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains an analysis for computing the multi-level callgraph from a
// given top-level operation. The nodes within this callgraph are defined by
// the `CallOpInterface` and `CallableOpInterface` operation interfaces defined
// in CallInterface.td.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef MLIR_ANALYSIS_CALLGRAPH_H
17#define MLIR_ANALYSIS_CALLGRAPH_H
18
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/GraphTraits.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/ADT/PointerIntPair.h"
23#include "llvm/ADT/SetVector.h"
24
25namespace mlir {
26class CallOpInterface;
27struct CallInterfaceCallable;
28class Operation;
29class Region;
30class SymbolTableCollection;
31
32//===----------------------------------------------------------------------===//
33// CallGraphNode
34//===----------------------------------------------------------------------===//
35
36/// This class represents a single callable in the callgraph. Aside from the
37/// external node, each node represents a callable node in the graph and
38/// contains a valid corresponding Region. The external node is a virtual node
39/// used to represent external edges into, and out of, the callgraph.
40class CallGraphNode {
41public:
42 /// This class represents a directed edge between two nodes in the callgraph.
43 class Edge {
44 enum class Kind {
45 // An 'Abstract' edge represents an opaque, non-operation, reference
46 // between this node and the target. Edges of this type are only valid
47 // from the external node, as there is no valid connection to an operation
48 // in the module.
49 Abstract,
50
51 // A 'Call' edge represents a direct reference to the target node via a
52 // call-like operation within the callable region of this node.
53 Call,
54
55 // A 'Child' edge is used when the region of target node is defined inside
56 // of the callable region of this node. This means that the region of this
57 // node is an ancestor of the region for the target node. As such, this
58 // edge cannot be used on the 'external' node.
59 Child,
60 };
61
62 public:
63 /// Returns true if this edge represents an `Abstract` edge.
64 bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; }
65
66 /// Returns true if this edge represents a `Call` edge.
67 bool isCall() const { return targetAndKind.getInt() == Kind::Call; }
68
69 /// Returns true if this edge represents a `Child` edge.
70 bool isChild() const { return targetAndKind.getInt() == Kind::Child; }
71
72 /// Returns the target node for this edge.
73 CallGraphNode *getTarget() const { return targetAndKind.getPointer(); }
74
75 bool operator==(const Edge &edge) const {
76 return targetAndKind == edge.targetAndKind;
77 }
78
79 private:
80 Edge(CallGraphNode *node, Kind kind) : targetAndKind(node, kind) {}
81 explicit Edge(llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind)
82 : targetAndKind(targetAndKind) {}
83
84 /// The target node of this edge, as well as the edge kind.
85 llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind;
86
87 // Provide access to the constructor and Kind.
88 friend class CallGraphNode;
89 };
90
91 /// Returns true if this node is an external node.
92 bool isExternal() const;
93
94 /// Returns the callable region this node represents. This can only be called
95 /// on non-external nodes.
96 Region *getCallableRegion() const;
97
98 /// Adds an abstract reference edge to the given node. An abstract edge does
99 /// not come from any observable operations, so this is only valid on the
100 /// external node.
101 void addAbstractEdge(CallGraphNode *node);
102
103 /// Add an outgoing call edge from this node.
104 void addCallEdge(CallGraphNode *node);
105
106 /// Adds a reference edge to the given child node.
107 void addChildEdge(CallGraphNode *child);
108
109 /// Iterator over the outgoing edges of this node.
110 using iterator = SmallVectorImpl<Edge>::const_iterator;
111 iterator begin() const { return edges.begin(); }
112 iterator end() const { return edges.end(); }
113
114 /// Returns true if this node has any child edges.
115 bool hasChildren() const;
116
117private:
118 /// DenseMap info for callgraph edges.
119 struct EdgeKeyInfo {
120 using BaseInfo =
121 DenseMapInfo<llvm::PointerIntPair<CallGraphNode *, 2, Edge::Kind>>;
122
123 static Edge getEmptyKey() { return Edge(BaseInfo::getEmptyKey()); }
124 static Edge getTombstoneKey() { return Edge(BaseInfo::getTombstoneKey()); }
125 static unsigned getHashValue(const Edge &edge) {
126 return BaseInfo::getHashValue(V: edge.targetAndKind);
127 }
128 static bool isEqual(const Edge &lhs, const Edge &rhs) { return lhs == rhs; }
129 };
130
131 CallGraphNode(Region *callableRegion) : callableRegion(callableRegion) {}
132
133 /// Add an edge to 'node' with the given kind.
134 void addEdge(CallGraphNode *node, Edge::Kind kind);
135
136 /// The callable region defines the boundary of the call graph node. This is
137 /// the region referenced by 'call' operations. This is at a per-region
138 /// boundary as operations may define multiple callable regions.
139 Region *callableRegion;
140
141 /// A set of out-going edges from this node to other nodes in the graph.
142 SetVector<Edge, SmallVector<Edge, 4>,
143 llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>>
144 edges;
145
146 // Provide access to private methods.
147 friend class CallGraph;
148};
149
150//===----------------------------------------------------------------------===//
151// CallGraph
152//===----------------------------------------------------------------------===//
153
154class CallGraph {
155 using NodeMapT = llvm::MapVector<Region *, std::unique_ptr<CallGraphNode>>;
156
157 /// This class represents an iterator over the internal call graph nodes. This
158 /// class unwraps the map iterator to access the raw node.
159 class NodeIterator final
160 : public llvm::mapped_iterator<
161 NodeMapT::const_iterator,
162 CallGraphNode *(*)(const NodeMapT::value_type &)> {
163 static CallGraphNode *unwrap(const NodeMapT::value_type &value) {
164 return value.second.get();
165 }
166
167 public:
168 /// Initializes the result type iterator to the specified result iterator.
169 NodeIterator(NodeMapT::const_iterator it)
170 : llvm::mapped_iterator<
171 NodeMapT::const_iterator,
172 CallGraphNode *(*)(const NodeMapT::value_type &)>(it, &unwrap) {}
173 };
174
175public:
176 CallGraph(Operation *op);
177
178 /// Get or add a call graph node for the given region. `parentNode`
179 /// corresponds to the direct node in the callgraph that contains the parent
180 /// operation of `region`, or nullptr if there is no parent node.
181 CallGraphNode *getOrAddNode(Region *region, CallGraphNode *parentNode);
182
183 /// Lookup a call graph node for the given region, or nullptr if none is
184 /// registered.
185 CallGraphNode *lookupNode(Region *region) const;
186
187 /// Return the callgraph node representing an external caller.
188 CallGraphNode *getExternalCallerNode() const {
189 return const_cast<CallGraphNode *>(&externalCallerNode);
190 }
191
192 /// Return the callgraph node representing an indirect callee.
193 CallGraphNode *getUnknownCalleeNode() const {
194 return const_cast<CallGraphNode *>(&unknownCalleeNode);
195 }
196
197 /// Resolve the callable for given callee to a node in the callgraph, or the
198 /// external node if a valid node was not resolved. The provided symbol table
199 /// is used when resolving calls that reference callables via a symbol
200 /// reference.
201 CallGraphNode *resolveCallable(CallOpInterface call,
202 SymbolTableCollection &symbolTable) const;
203
204 /// Erase the given node from the callgraph.
205 void eraseNode(CallGraphNode *node);
206
207 /// An iterator over the nodes of the graph.
208 using iterator = NodeIterator;
209 iterator begin() const { return nodes.begin(); }
210 iterator end() const { return nodes.end(); }
211
212 /// Dump the graph in a human readable format.
213 void dump() const;
214 void print(raw_ostream &os) const;
215
216private:
217 /// The set of nodes within the callgraph.
218 NodeMapT nodes;
219
220 /// A special node used to indicate an external caller.
221 CallGraphNode externalCallerNode;
222
223 /// A special node used to indicate an unknown callee.
224 CallGraphNode unknownCalleeNode;
225};
226
227} // namespace mlir
228
229namespace llvm {
230// Provide graph traits for traversing call graphs using standard graph
231// traversals.
232template <>
233struct GraphTraits<const mlir::CallGraphNode *> {
234 using NodeRef = mlir::CallGraphNode *;
235 static NodeRef getEntryNode(NodeRef node) { return node; }
236
237 static NodeRef unwrap(const mlir::CallGraphNode::Edge &edge) {
238 return edge.getTarget();
239 }
240
241 // ChildIteratorType/begin/end - Allow iteration over all nodes in the graph.
242 using ChildIteratorType =
243 mapped_iterator<mlir::CallGraphNode::iterator, decltype(&unwrap)>;
244 static ChildIteratorType child_begin(NodeRef node) {
245 return {node->begin(), &unwrap};
246 }
247 static ChildIteratorType child_end(NodeRef node) {
248 return {node->end(), &unwrap};
249 }
250};
251
252template <>
253struct GraphTraits<const mlir::CallGraph *>
254 : public GraphTraits<const mlir::CallGraphNode *> {
255 /// The entry node into the graph is the external node.
256 static NodeRef getEntryNode(const mlir::CallGraph *cg) {
257 return cg->getExternalCallerNode();
258 }
259
260 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
261 using nodes_iterator = mlir::CallGraph::iterator;
262 static nodes_iterator nodes_begin(mlir::CallGraph *cg) { return cg->begin(); }
263 static nodes_iterator nodes_end(mlir::CallGraph *cg) { return cg->end(); }
264};
265} // namespace llvm
266
267#endif // MLIR_ANALYSIS_CALLGRAPH_H
268

// Source: mlir/include/mlir/Analysis/CallGraph.h