1//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains interfaces and analyses for defining a nested callgraph.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
24
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// CallGraphNode
29//===----------------------------------------------------------------------===//
30
31/// Returns true if this node refers to the indirect/external node.
32bool CallGraphNode::isExternal() const { return !callableRegion; }
33
34/// Return the callable region this node represents. This can only be called
35/// on non-external nodes.
36Region *CallGraphNode::getCallableRegion() const {
37 assert(!isExternal() && "the external node has no callable region");
38 return callableRegion;
39}
40
41/// Adds an reference edge to the given node. This is only valid on the
42/// external node.
43void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
44 assert(isExternal() && "abstract edges are only valid on external nodes");
45 addEdge(node, kind: Edge::Kind::Abstract);
46}
47
48/// Add an outgoing call edge from this node.
49void CallGraphNode::addCallEdge(CallGraphNode *node) {
50 addEdge(node, kind: Edge::Kind::Call);
51}
52
53/// Adds a reference edge to the given child node.
54void CallGraphNode::addChildEdge(CallGraphNode *child) {
55 addEdge(node: child, kind: Edge::Kind::Child);
56}
57
58/// Returns true if this node has any child edges.
59bool CallGraphNode::hasChildren() const {
60 return llvm::any_of(Range: edges, P: [](const Edge &edge) { return edge.isChild(); });
61}
62
63/// Add an edge to 'node' with the given kind.
64void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
65 edges.insert(X: {node, kind});
66}
67
68//===----------------------------------------------------------------------===//
69// CallGraph
70//===----------------------------------------------------------------------===//
71
72/// Recursively compute the callgraph edges for the given operation. Computed
73/// edges are placed into the given callgraph object.
74static void computeCallGraph(Operation *op, CallGraph &cg,
75 SymbolTableCollection &symbolTable,
76 CallGraphNode *parentNode, bool resolveCalls) {
77 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
78 // If there is no parent node, we ignore this operation. Even if this
79 // operation was a call, there would be no callgraph node to attribute it
80 // to.
81 if (resolveCalls && parentNode)
82 parentNode->addCallEdge(node: cg.resolveCallable(call: call, symbolTable));
83 return;
84 }
85
86 // Compute the callgraph nodes and edges for each of the nested operations.
87 if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
88 if (auto *callableRegion = callable.getCallableRegion())
89 parentNode = cg.getOrAddNode(region: callableRegion, parentNode);
90 else
91 return;
92 }
93
94 for (Region &region : op->getRegions())
95 for (Operation &nested : region.getOps())
96 computeCallGraph(op: &nested, cg, symbolTable, parentNode, resolveCalls);
97}
98
99CallGraph::CallGraph(Operation *op)
100 : externalCallerNode(/*callableRegion=*/nullptr),
101 unknownCalleeNode(/*callableRegion=*/nullptr) {
102 // Make two passes over the graph, one to compute the callables and one to
103 // resolve the calls. We split these up as we may have nested callable objects
104 // that need to be reserved before the calls.
105 SymbolTableCollection symbolTable;
106 computeCallGraph(op, cg&: *this, symbolTable, /*parentNode=*/nullptr,
107 /*resolveCalls=*/false);
108 computeCallGraph(op, cg&: *this, symbolTable, /*parentNode=*/nullptr,
109 /*resolveCalls=*/true);
110}
111
112/// Get or add a call graph node for the given region.
113CallGraphNode *CallGraph::getOrAddNode(Region *region,
114 CallGraphNode *parentNode) {
115 assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
116 "expected parent operation to be callable");
117 std::unique_ptr<CallGraphNode> &node = nodes[region];
118 if (!node) {
119 node.reset(p: new CallGraphNode(region));
120
121 // Add this node to the given parent node if necessary.
122 if (parentNode) {
123 parentNode->addChildEdge(child: node.get());
124 } else {
125 // Otherwise, connect all callable nodes to the external node, this allows
126 // for conservatively including all callable nodes within the graph.
127 // FIXME This isn't correct, this is only necessary for callable nodes
128 // that *could* be called from external sources. This requires extending
129 // the interface for callables to check if they may be referenced
130 // externally.
131 externalCallerNode.addAbstractEdge(node: node.get());
132 }
133 }
134 return node.get();
135}
136
137/// Lookup a call graph node for the given region, or nullptr if none is
138/// registered.
139CallGraphNode *CallGraph::lookupNode(Region *region) const {
140 const auto *it = nodes.find(Key: region);
141 return it == nodes.end() ? nullptr : it->second.get();
142}
143
144/// Resolve the callable for given callee to a node in the callgraph, or the
145/// unknown callee node if a valid node was not resolved.
146CallGraphNode *
147CallGraph::resolveCallable(CallOpInterface call,
148 SymbolTableCollection &symbolTable) const {
149 Operation *callable = call.resolveCallable(&symbolTable);
150 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
151 if (auto *node = lookupNode(callableOp.getCallableRegion()))
152 return node;
153
154 return getUnknownCalleeNode();
155}
156
157/// Erase the given node from the callgraph.
158void CallGraph::eraseNode(CallGraphNode *node) {
159 // Erase any children of this node first.
160 if (node->hasChildren()) {
161 for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(Range&: *node))
162 if (edge.isChild())
163 eraseNode(node: edge.getTarget());
164 }
165 // Erase any edges to this node from any other nodes.
166 for (auto &it : nodes) {
167 it.second->edges.remove_if(P: [node](const CallGraphNode::Edge &edge) {
168 return edge.getTarget() == node;
169 });
170 }
171 nodes.erase(Key: node->getCallableRegion());
172}
173
174//===----------------------------------------------------------------------===//
175// Printing
176
177/// Dump the graph in a human readable format.
178void CallGraph::dump() const { print(os&: llvm::errs()); }
179void CallGraph::print(raw_ostream &os) const {
180 os << "// ---- CallGraph ----\n";
181
182 // Functor used to output the name for the given node.
183 auto emitNodeName = [&](const CallGraphNode *node) {
184 if (node == getExternalCallerNode()) {
185 os << "<External-Caller-Node>";
186 return;
187 }
188 if (node == getUnknownCalleeNode()) {
189 os << "<Unknown-Callee-Node>";
190 return;
191 }
192
193 auto *callableRegion = node->getCallableRegion();
194 auto *parentOp = callableRegion->getParentOp();
195 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
196 << callableRegion->getRegionNumber();
197 auto attrs = parentOp->getAttrDictionary();
198 if (!attrs.empty())
199 os << " : " << attrs;
200 };
201
202 for (auto &nodeIt : nodes) {
203 const CallGraphNode *node = nodeIt.second.get();
204
205 // Dump the header for this node.
206 os << "// - Node : ";
207 emitNodeName(node);
208 os << "\n";
209
210 // Emit each of the edges.
211 for (auto &edge : *node) {
212 os << "// -- ";
213 if (edge.isCall())
214 os << "Call";
215 else if (edge.isChild())
216 os << "Child";
217
218 os << "-Edge : ";
219 emitNodeName(edge.getTarget());
220 os << "\n";
221 }
222 os << "//\n";
223 }
224
225 os << "// -- SCCs --\n";
226
227 for (auto &scc : make_range(x: llvm::scc_begin(G: this), y: llvm::scc_end(G: this))) {
228 os << "// - SCC : \n";
229 for (auto &node : scc) {
230 os << "// -- Node :";
231 emitNodeName(node);
232 os << "\n";
233 }
234 os << "\n";
235 }
236
237 os << "// -------------------\n";
238}
239

source code of mlir/lib/Analysis/CallGraph.cpp