1//======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12#include "mlir/Interfaces/CallInterfaces.h"
13#include "mlir/Interfaces/ControlFlowInterfaces.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15#include "mlir/Interfaces/ViewLikeInterface.h"
16#include "llvm/ADT/SetOperations.h"
17
18using namespace mlir;
19using namespace mlir::bufferization;
20
21//===----------------------------------------------------------------------===//
22// BufferViewFlowAnalysis
23//===----------------------------------------------------------------------===//
24
25/// Constructs a new alias analysis using the op provided.
26BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
27
28static BufferViewFlowAnalysis::ValueSetT
29resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
30 BufferViewFlowAnalysis::ValueSetT result;
31 SmallVector<Value, 8> queue;
32 queue.push_back(Elt: value);
33 while (!queue.empty()) {
34 Value currentValue = queue.pop_back_val();
35 if (result.insert(Ptr: currentValue).second) {
36 auto it = map.find(Val: currentValue);
37 if (it != map.end()) {
38 for (Value aliasValue : it->second)
39 queue.push_back(Elt: aliasValue);
40 }
41 }
42 }
43 return result;
44}
45
46/// Find all immediate and indirect dependent buffers this value could
47/// potentially have. Note that the resulting set will also contain the value
48/// provided as it is a dependent alias of itself.
49BufferViewFlowAnalysis::ValueSetT
50BufferViewFlowAnalysis::resolve(Value rootValue) const {
51 return resolveValues(map: dependencies, value: rootValue);
52}
53
54BufferViewFlowAnalysis::ValueSetT
55BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
56 return resolveValues(map: reverseDependencies, value: rootValue);
57}
58
59/// Removes the given values from all alias sets.
60void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
61 for (auto &entry : dependencies)
62 llvm::set_subtract(S1&: entry.second, S2: aliasValues);
63}
64
65void BufferViewFlowAnalysis::rename(Value from, Value to) {
66 dependencies[to] = dependencies[from];
67 dependencies.erase(Val: from);
68
69 for (auto &[_, value] : dependencies) {
70 if (value.contains(Ptr: from)) {
71 value.insert(Ptr: to);
72 value.erase(Ptr: from);
73 }
74 }
75}
76
77/// This function constructs a mapping from values to its immediate
78/// dependencies. It iterates over all blocks, gets their predecessors,
79/// determines the values that will be passed to the corresponding block
80/// arguments and inserts them into the underlying map. Furthermore, it wires
81/// successor regions and branch-like return operations from nested regions.
82void BufferViewFlowAnalysis::build(Operation *op) {
83 // Registers all dependencies of the given values.
84 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
85 for (auto [value, dep] : llvm::zip_equal(t&: values, u&: dependencies)) {
86 this->dependencies[value].insert(Ptr: dep);
87 this->reverseDependencies[dep].insert(Ptr: value);
88 }
89 };
90
91 // Mark all buffer results and buffer region entry block arguments of the
92 // given op as terminals.
93 auto populateTerminalValues = [&](Operation *op) {
94 for (Value v : op->getResults())
95 if (isa<BaseMemRefType>(Val: v.getType()))
96 this->terminals.insert(V: v);
97 for (Region &r : op->getRegions())
98 for (BlockArgument v : r.getArguments())
99 if (isa<BaseMemRefType>(Val: v.getType()))
100 this->terminals.insert(V: v);
101 };
102
103 op->walk(callback: [&](Operation *op) {
104 // Query BufferViewFlowOpInterface. If the op does not implement that
105 // interface, try to infer the dependencies from other interfaces that the
106 // op may implement.
107 if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(Val: op)) {
108 bufferViewFlowOp.populateDependencies(registerDependenciesFn: registerDependencies);
109 for (Value v : op->getResults())
110 if (isa<BaseMemRefType>(Val: v.getType()) &&
111 bufferViewFlowOp.mayBeTerminalBuffer(value: v))
112 this->terminals.insert(V: v);
113 for (Region &r : op->getRegions())
114 for (BlockArgument v : r.getArguments())
115 if (isa<BaseMemRefType>(Val: v.getType()) &&
116 bufferViewFlowOp.mayBeTerminalBuffer(value: v))
117 this->terminals.insert(V: v);
118 return WalkResult::advance();
119 }
120
121 // Add additional dependencies created by view changes to the alias list.
122 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(Val: op)) {
123 registerDependencies(viewInterface.getViewSource(),
124 viewInterface->getResult(idx: 0));
125 return WalkResult::advance();
126 }
127
128 if (auto branchInterface = dyn_cast<BranchOpInterface>(Val: op)) {
129 // Query all branch interfaces to link block argument dependencies.
130 Block *parentBlock = branchInterface->getBlock();
131 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
132 it != e; ++it) {
133 // Query the branch op interface to get the successor operands.
134 auto successorOperands =
135 branchInterface.getSuccessorOperands(index: it.getIndex());
136 // Build the actual mapping of values to their immediate dependencies.
137 registerDependencies(successorOperands.getForwardedOperands(),
138 (*it)->getArguments().drop_front(
139 N: successorOperands.getProducedOperandCount()));
140 }
141 return WalkResult::advance();
142 }
143
144 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(Val: op)) {
145 // Query the RegionBranchOpInterface to find potential successor regions.
146 // Extract all entry regions and wire all initial entry successor inputs.
147 SmallVector<RegionSuccessor, 2> entrySuccessors;
148 regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
149 regions&: entrySuccessors);
150 for (RegionSuccessor &entrySuccessor : entrySuccessors) {
151 // Wire the entry region's successor arguments with the initial
152 // successor inputs.
153 registerDependencies(
154 regionInterface.getEntrySuccessorOperands(point: entrySuccessor),
155 entrySuccessor.getSuccessorInputs());
156 }
157
158 // Wire flow between regions and from region exits.
159 for (Region &region : regionInterface->getRegions()) {
160 // Iterate over all successor region entries that are reachable from the
161 // current region.
162 SmallVector<RegionSuccessor, 2> successorRegions;
163 regionInterface.getSuccessorRegions(point: region, regions&: successorRegions);
164 for (RegionSuccessor &successorRegion : successorRegions) {
165 // Iterate over all immediate terminator operations and wire the
166 // successor inputs with the successor operands of each terminator.
167 for (Block &block : region)
168 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
169 Val: block.getTerminator()))
170 registerDependencies(
171 terminator.getSuccessorOperands(point: successorRegion),
172 successorRegion.getSuccessorInputs());
173 }
174 }
175
176 return WalkResult::advance();
177 }
178
179 // Region terminators are handled together with RegionBranchOpInterface.
180 if (isa<RegionBranchTerminatorOpInterface>(Val: op))
181 return WalkResult::advance();
182
183 if (isa<CallOpInterface>(Val: op)) {
184 // This is an intra-function analysis. We have no information about other
185 // functions. Conservatively assume that each operand may alias with each
186 // result. Also mark the results are terminals because the function could
187 // return newly allocated buffers.
188 populateTerminalValues(op);
189 for (Value operand : op->getOperands())
190 for (Value result : op->getResults())
191 registerDependencies({operand}, {result});
192 return WalkResult::advance();
193 }
194
195 // We have no information about unknown ops.
196 populateTerminalValues(op);
197
198 return WalkResult::advance();
199 });
200}
201
202bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
203 assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
204 return terminals.contains(V: value);
205}
206
207//===----------------------------------------------------------------------===//
208// BufferOriginAnalysis
209//===----------------------------------------------------------------------===//
210
211/// Return "true" if the given value is the result of a memory allocation.
212static bool hasAllocateSideEffect(Value v) {
213 Operation *op = v.getDefiningOp();
214 if (!op)
215 return false;
216 return hasEffect<MemoryEffects::Allocate>(op, value: v);
217}
218
219/// Return "true" if the given value is a function block argument.
220static bool isFunctionArgument(Value v) {
221 auto bbArg = dyn_cast<BlockArgument>(Val&: v);
222 if (!bbArg)
223 return false;
224 Block *b = bbArg.getOwner();
225 auto funcOp = dyn_cast<FunctionOpInterface>(Val: b->getParentOp());
226 if (!funcOp)
227 return false;
228 return bbArg.getOwner() == &funcOp.getFunctionBody().front();
229}
230
231/// Given a memref value, return the "base" value by skipping over all
232/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
233static Value getViewBase(Value value) {
234 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
235 value = viewLikeOp.getViewSource();
236 return value;
237}
238
239BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
240
241std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
242 assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
243 assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
244
245 // Skip over all view-like ops.
246 v1 = getViewBase(value: v1);
247 v2 = getViewBase(value: v2);
248
249 // Fast path: If both buffers are the same SSA value, we can be sure that
250 // they originate from the same allocation.
251 if (v1 == v2)
252 return true;
253
254 // Compute the SSA values from which the buffers `v1` and `v2` originate.
255 SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(rootValue: v1);
256 SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(rootValue: v2);
257
258 // Originating buffers are "terminal" if they could not be traced back any
259 // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
260 // - function block arguments
261 // - values defined by allocation ops such as "memref.alloc"
262 // - values defined by ops that are unknown to the buffer view flow analysis
263 // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
264 SmallPtrSet<Value, 16> terminal1, terminal2;
265
266 // While gathering terminal buffers, keep track of whether all terminal
267 // buffers are newly allocated buffer or function entry arguments.
268 bool allAllocs1 = true, allAllocs2 = true;
269 bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
270
271 // Helper function that gathers terminal buffers among `origin`.
272 auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
273 SmallPtrSet<Value, 16> &terminal,
274 bool &allAllocs,
275 bool &allAllocsOrFuncEntryArgs) {
276 for (Value v : origin) {
277 if (isa<BaseMemRefType>(Val: v.getType()) && analysis.mayBeTerminalBuffer(value: v)) {
278 terminal.insert(Ptr: v);
279 allAllocs &= hasAllocateSideEffect(v);
280 allAllocsOrFuncEntryArgs &=
281 isFunctionArgument(v) || hasAllocateSideEffect(v);
282 }
283 }
284 assert(!terminal.empty() && "expected non-empty terminal set");
285 };
286
287 // Gather terminal buffers for `v1` and `v2`.
288 gatherTerminalBuffers(origin1, terminal1, allAllocs1,
289 allAllocsOrFuncEntryArgs1);
290 gatherTerminalBuffers(origin2, terminal2, allAllocs2,
291 allAllocsOrFuncEntryArgs2);
292
293 // If both `v1` and `v2` have a single matching terminal buffer, they are
294 // guaranteed to originate from the same buffer allocation.
295 if (llvm::hasSingleElement(C&: terminal1) && llvm::hasSingleElement(C&: terminal2) &&
296 *terminal1.begin() == *terminal2.begin())
297 return true;
298
299 // At least one of the two values has multiple terminals.
300
301 // Check if there is overlap between the terminal buffers of `v1` and `v2`.
302 bool distinctTerminalSets = true;
303 for (Value v : terminal1)
304 distinctTerminalSets &= !terminal2.contains(Ptr: v);
305 // If there is overlap between the terminal buffers of `v1` and `v2`, we
306 // cannot make an accurate decision without further analysis.
307 if (!distinctTerminalSets)
308 return std::nullopt;
309
310 // If `v1` originates from only allocs, and `v2` is guaranteed to originate
311 // from different allocations (that is guaranteed if `v2` originates from
312 // only distinct allocs or function entry arguments), we can be sure that
313 // `v1` and `v2` originate from different allocations. The same argument can
314 // be made when swapping `v1` and `v2`.
315 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
316 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
317 if (isolatedAlloc1 || isolatedAlloc2)
318 return false;
319
320 // Otherwise: We do not know whether `v1` and `v2` originate from the same
321 // allocation or not.
322 // TODO: Function arguments are currently handled conservatively. We assume
323 // that they could be the same allocation.
324 // TODO: Terminals other than allocations and function arguments are
325 // currently handled conservatively. We assume that they could be the same
326 // allocation. E.g., we currently return "nullopt" for values that originate
327 // from different "memref.get_global" ops (with different symbols).
328 return std::nullopt;
329}
330

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp