1//======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12#include "mlir/Interfaces/CallInterfaces.h"
13#include "mlir/Interfaces/ControlFlowInterfaces.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15#include "mlir/Interfaces/ViewLikeInterface.h"
16#include "llvm/ADT/SetOperations.h"
17#include "llvm/ADT/SetVector.h"
18
19using namespace mlir;
20using namespace mlir::bufferization;
21
22//===----------------------------------------------------------------------===//
23// BufferViewFlowAnalysis
24//===----------------------------------------------------------------------===//
25
26/// Constructs a new alias analysis using the op provided.
27BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
28
29static BufferViewFlowAnalysis::ValueSetT
30resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
31 BufferViewFlowAnalysis::ValueSetT result;
32 SmallVector<Value, 8> queue;
33 queue.push_back(Elt: value);
34 while (!queue.empty()) {
35 Value currentValue = queue.pop_back_val();
36 if (result.insert(Ptr: currentValue).second) {
37 auto it = map.find(Val: currentValue);
38 if (it != map.end()) {
39 for (Value aliasValue : it->second)
40 queue.push_back(Elt: aliasValue);
41 }
42 }
43 }
44 return result;
45}
46
47/// Find all immediate and indirect dependent buffers this value could
48/// potentially have. Note that the resulting set will also contain the value
49/// provided as it is a dependent alias of itself.
50BufferViewFlowAnalysis::ValueSetT
51BufferViewFlowAnalysis::resolve(Value rootValue) const {
52 return resolveValues(map: dependencies, value: rootValue);
53}
54
55BufferViewFlowAnalysis::ValueSetT
56BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
57 return resolveValues(map: reverseDependencies, value: rootValue);
58}
59
60/// Removes the given values from all alias sets.
61void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
62 for (auto &entry : dependencies)
63 llvm::set_subtract(S1&: entry.second, S2: aliasValues);
64}
65
66void BufferViewFlowAnalysis::rename(Value from, Value to) {
67 dependencies[to] = dependencies[from];
68 dependencies.erase(Val: from);
69
70 for (auto &[_, value] : dependencies) {
71 if (value.contains(Ptr: from)) {
72 value.insert(Ptr: to);
73 value.erase(Ptr: from);
74 }
75 }
76}
77
78/// This function constructs a mapping from values to its immediate
79/// dependencies. It iterates over all blocks, gets their predecessors,
80/// determines the values that will be passed to the corresponding block
81/// arguments and inserts them into the underlying map. Furthermore, it wires
82/// successor regions and branch-like return operations from nested regions.
83void BufferViewFlowAnalysis::build(Operation *op) {
84 // Registers all dependencies of the given values.
85 auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
86 for (auto [value, dep] : llvm::zip_equal(t&: values, u&: dependencies)) {
87 this->dependencies[value].insert(Ptr: dep);
88 this->reverseDependencies[dep].insert(Ptr: value);
89 }
90 };
91
92 // Mark all buffer results and buffer region entry block arguments of the
93 // given op as terminals.
94 auto populateTerminalValues = [&](Operation *op) {
95 for (Value v : op->getResults())
96 if (isa<BaseMemRefType>(Val: v.getType()))
97 this->terminals.insert(V: v);
98 for (Region &r : op->getRegions())
99 for (BlockArgument v : r.getArguments())
100 if (isa<BaseMemRefType>(Val: v.getType()))
101 this->terminals.insert(V: v);
102 };
103
104 op->walk(callback: [&](Operation *op) {
105 // Query BufferViewFlowOpInterface. If the op does not implement that
106 // interface, try to infer the dependencies from other interfaces that the
107 // op may implement.
108 if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
109 bufferViewFlowOp.populateDependencies(registerDependencies);
110 for (Value v : op->getResults())
111 if (isa<BaseMemRefType>(Val: v.getType()) &&
112 bufferViewFlowOp.mayBeTerminalBuffer(v))
113 this->terminals.insert(V: v);
114 for (Region &r : op->getRegions())
115 for (BlockArgument v : r.getArguments())
116 if (isa<BaseMemRefType>(Val: v.getType()) &&
117 bufferViewFlowOp.mayBeTerminalBuffer(v))
118 this->terminals.insert(V: v);
119 return WalkResult::advance();
120 }
121
122 // Add additional dependencies created by view changes to the alias list.
123 if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
124 registerDependencies(viewInterface.getViewSource(),
125 viewInterface->getResult(0));
126 return WalkResult::advance();
127 }
128
129 if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
130 // Query all branch interfaces to link block argument dependencies.
131 Block *parentBlock = branchInterface->getBlock();
132 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
133 it != e; ++it) {
134 // Query the branch op interface to get the successor operands.
135 auto successorOperands =
136 branchInterface.getSuccessorOperands(it.getIndex());
137 // Build the actual mapping of values to their immediate dependencies.
138 registerDependencies(successorOperands.getForwardedOperands(),
139 (*it)->getArguments().drop_front(
140 successorOperands.getProducedOperandCount()));
141 }
142 return WalkResult::advance();
143 }
144
145 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
146 // Query the RegionBranchOpInterface to find potential successor regions.
147 // Extract all entry regions and wire all initial entry successor inputs.
148 SmallVector<RegionSuccessor, 2> entrySuccessors;
149 regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
150 entrySuccessors);
151 for (RegionSuccessor &entrySuccessor : entrySuccessors) {
152 // Wire the entry region's successor arguments with the initial
153 // successor inputs.
154 registerDependencies(
155 regionInterface.getEntrySuccessorOperands(entrySuccessor),
156 entrySuccessor.getSuccessorInputs());
157 }
158
159 // Wire flow between regions and from region exits.
160 for (Region &region : regionInterface->getRegions()) {
161 // Iterate over all successor region entries that are reachable from the
162 // current region.
163 SmallVector<RegionSuccessor, 2> successorRegions;
164 regionInterface.getSuccessorRegions(region, successorRegions);
165 for (RegionSuccessor &successorRegion : successorRegions) {
166 // Iterate over all immediate terminator operations and wire the
167 // successor inputs with the successor operands of each terminator.
168 for (Block &block : region)
169 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
170 block.getTerminator()))
171 registerDependencies(
172 terminator.getSuccessorOperands(successorRegion),
173 successorRegion.getSuccessorInputs());
174 }
175 }
176
177 return WalkResult::advance();
178 }
179
180 // Region terminators are handled together with RegionBranchOpInterface.
181 if (isa<RegionBranchTerminatorOpInterface>(op))
182 return WalkResult::advance();
183
184 if (isa<CallOpInterface>(op)) {
185 // This is an intra-function analysis. We have no information about other
186 // functions. Conservatively assume that each operand may alias with each
187 // result. Also mark the results are terminals because the function could
188 // return newly allocated buffers.
189 populateTerminalValues(op);
190 for (Value operand : op->getOperands())
191 for (Value result : op->getResults())
192 registerDependencies({operand}, {result});
193 return WalkResult::advance();
194 }
195
196 // We have no information about unknown ops.
197 populateTerminalValues(op);
198
199 return WalkResult::advance();
200 });
201}
202
203bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
204 assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
205 return terminals.contains(V: value);
206}
207
208//===----------------------------------------------------------------------===//
209// BufferOriginAnalysis
210//===----------------------------------------------------------------------===//
211
212/// Return "true" if the given value is the result of a memory allocation.
213static bool hasAllocateSideEffect(Value v) {
214 Operation *op = v.getDefiningOp();
215 if (!op)
216 return false;
217 return hasEffect<MemoryEffects::Allocate>(op, value: v);
218}
219
220/// Return "true" if the given value is a function block argument.
221static bool isFunctionArgument(Value v) {
222 auto bbArg = dyn_cast<BlockArgument>(Val&: v);
223 if (!bbArg)
224 return false;
225 Block *b = bbArg.getOwner();
226 auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
227 if (!funcOp)
228 return false;
229 return bbArg.getOwner() == &funcOp.getFunctionBody().front();
230}
231
232/// Given a memref value, return the "base" value by skipping over all
233/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
234static Value getViewBase(Value value) {
235 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
236 value = viewLikeOp.getViewSource();
237 return value;
238}
239
240BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
241
242std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
243 assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
244 assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
245
246 // Skip over all view-like ops.
247 v1 = getViewBase(value: v1);
248 v2 = getViewBase(value: v2);
249
250 // Fast path: If both buffers are the same SSA value, we can be sure that
251 // they originate from the same allocation.
252 if (v1 == v2)
253 return true;
254
255 // Compute the SSA values from which the buffers `v1` and `v2` originate.
256 SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(rootValue: v1);
257 SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(rootValue: v2);
258
259 // Originating buffers are "terminal" if they could not be traced back any
260 // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
261 // - function block arguments
262 // - values defined by allocation ops such as "memref.alloc"
263 // - values defined by ops that are unknown to the buffer view flow analysis
264 // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
265 SmallPtrSet<Value, 16> terminal1, terminal2;
266
267 // While gathering terminal buffers, keep track of whether all terminal
268 // buffers are newly allocated buffer or function entry arguments.
269 bool allAllocs1 = true, allAllocs2 = true;
270 bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
271
272 // Helper function that gathers terminal buffers among `origin`.
273 auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
274 SmallPtrSet<Value, 16> &terminal,
275 bool &allAllocs,
276 bool &allAllocsOrFuncEntryArgs) {
277 for (Value v : origin) {
278 if (isa<BaseMemRefType>(Val: v.getType()) && analysis.mayBeTerminalBuffer(value: v)) {
279 terminal.insert(Ptr: v);
280 allAllocs &= hasAllocateSideEffect(v);
281 allAllocsOrFuncEntryArgs &=
282 isFunctionArgument(v) || hasAllocateSideEffect(v);
283 }
284 }
285 assert(!terminal.empty() && "expected non-empty terminal set");
286 };
287
288 // Gather terminal buffers for `v1` and `v2`.
289 gatherTerminalBuffers(origin1, terminal1, allAllocs1,
290 allAllocsOrFuncEntryArgs1);
291 gatherTerminalBuffers(origin2, terminal2, allAllocs2,
292 allAllocsOrFuncEntryArgs2);
293
294 // If both `v1` and `v2` have a single matching terminal buffer, they are
295 // guaranteed to originate from the same buffer allocation.
296 if (llvm::hasSingleElement(C&: terminal1) && llvm::hasSingleElement(C&: terminal2) &&
297 *terminal1.begin() == *terminal2.begin())
298 return true;
299
300 // At least one of the two values has multiple terminals.
301
302 // Check if there is overlap between the terminal buffers of `v1` and `v2`.
303 bool distinctTerminalSets = true;
304 for (Value v : terminal1)
305 distinctTerminalSets &= !terminal2.contains(Ptr: v);
306 // If there is overlap between the terminal buffers of `v1` and `v2`, we
307 // cannot make an accurate decision without further analysis.
308 if (!distinctTerminalSets)
309 return std::nullopt;
310
311 // If `v1` originates from only allocs, and `v2` is guaranteed to originate
312 // from different allocations (that is guaranteed if `v2` originates from
313 // only distinct allocs or function entry arguments), we can be sure that
314 // `v1` and `v2` originate from different allocations. The same argument can
315 // be made when swapping `v1` and `v2`.
316 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
317 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
318 if (isolatedAlloc1 || isolatedAlloc2)
319 return false;
320
321 // Otherwise: We do not know whether `v1` and `v2` originate from the same
322 // allocation or not.
323 // TODO: Function arguments are currently handled conservatively. We assume
324 // that they could be the same allocation.
325 // TODO: Terminals other than allocations and function arguments are
326 // currently handled conservatively. We assume that they could be the same
327 // allocation. E.g., we currently return "nullopt" for values that originate
328 // from different "memref.get_global" ops (with different symbols).
329 return std::nullopt;
330}
331

// Source file: mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp