1 | //===- LocalAliasAnalysis.cpp - Local stateless alias Analysis for MLIR ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" |
10 | |
11 | #include "mlir/Analysis/AliasAnalysis.h" |
12 | #include "mlir/IR/Attributes.h" |
13 | #include "mlir/IR/Block.h" |
14 | #include "mlir/IR/Matchers.h" |
15 | #include "mlir/IR/OpDefinition.h" |
16 | #include "mlir/IR/Operation.h" |
17 | #include "mlir/IR/Region.h" |
18 | #include "mlir/IR/Value.h" |
19 | #include "mlir/IR/ValueRange.h" |
20 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
21 | #include "mlir/Interfaces/FunctionInterfaces.h" |
22 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | #include "mlir/Interfaces/ViewLikeInterface.h" |
24 | #include "mlir/Support/LLVM.h" |
25 | #include "llvm/Support/Casting.h" |
26 | #include <cassert> |
27 | #include <optional> |
28 | #include <utility> |
29 | |
30 | using namespace mlir; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Underlying Address Computation |
34 | //===----------------------------------------------------------------------===// |
35 | |
/// Upper bound on how many levels deep the underlying-value walk may recurse.
/// Once this budget is exhausted, the value currently being inspected is
/// reported as its own underlying value instead of being expanded further.
static constexpr unsigned maxUnderlyingValueSearchDepth{10};
39 | |
40 | /// Given a value, collect all of the underlying values being addressed. |
41 | static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, |
42 | DenseSet<Value> &visited, |
43 | SmallVectorImpl<Value> &output); |
44 | |
45 | /// Given a successor (`region`) of a RegionBranchOpInterface, collect all of |
46 | /// the underlying values being addressed by one of the successor inputs. If the |
47 | /// provided `region` is null, as per `RegionBranchOpInterface` this represents |
48 | /// the parent operation. |
49 | static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, |
50 | Region *region, Value inputValue, |
51 | unsigned inputIndex, |
52 | unsigned maxDepth, |
53 | DenseSet<Value> &visited, |
54 | SmallVectorImpl<Value> &output) { |
55 | // Given the index of a region of the branch (`predIndex`), or std::nullopt to |
56 | // represent the parent operation, try to return the index into the outputs of |
57 | // this region predecessor that correspond to the input values of `region`. If |
58 | // an index could not be found, std::nullopt is returned instead. |
59 | auto getOperandIndexIfPred = |
60 | [&](RegionBranchPoint pred) -> std::optional<unsigned> { |
61 | SmallVector<RegionSuccessor, 2> successors; |
62 | branch.getSuccessorRegions(pred, successors); |
63 | for (RegionSuccessor &successor : successors) { |
64 | if (successor.getSuccessor() != region) |
65 | continue; |
66 | // Check that the successor inputs map to the given input value. |
67 | ValueRange inputs = successor.getSuccessorInputs(); |
68 | if (inputs.empty()) { |
69 | output.push_back(Elt: inputValue); |
70 | break; |
71 | } |
72 | unsigned firstInputIndex, lastInputIndex; |
73 | if (region) { |
74 | firstInputIndex = cast<BlockArgument>(Val: inputs[0]).getArgNumber(); |
75 | lastInputIndex = cast<BlockArgument>(Val: inputs.back()).getArgNumber(); |
76 | } else { |
77 | firstInputIndex = cast<OpResult>(Val: inputs[0]).getResultNumber(); |
78 | lastInputIndex = cast<OpResult>(Val: inputs.back()).getResultNumber(); |
79 | } |
80 | if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { |
81 | output.push_back(Elt: inputValue); |
82 | break; |
83 | } |
84 | return inputIndex - firstInputIndex; |
85 | } |
86 | return std::nullopt; |
87 | }; |
88 | |
89 | // Check branches from the parent operation. |
90 | auto branchPoint = RegionBranchPoint::parent(); |
91 | if (region) |
92 | branchPoint = region; |
93 | |
94 | if (std::optional<unsigned> operandIndex = |
95 | getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) { |
96 | collectUnderlyingAddressValues( |
97 | branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth, |
98 | visited, output); |
99 | } |
100 | // Check branches from each child region. |
101 | Operation *op = branch.getOperation(); |
102 | for (Region ®ion : op->getRegions()) { |
103 | if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) { |
104 | for (Block &block : region) { |
105 | // Try to determine possible region-branch successor operands for the |
106 | // current region. |
107 | if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>( |
108 | block.getTerminator())) { |
109 | collectUnderlyingAddressValues( |
110 | term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth, |
111 | visited, output); |
112 | } else if (block.getNumSuccessors()) { |
113 | // Otherwise, if this terminator may exit the region we can't make |
114 | // any assumptions about which values get passed. |
115 | output.push_back(inputValue); |
116 | return; |
117 | } |
118 | } |
119 | } |
120 | } |
121 | } |
122 | |
123 | /// Given a result, collect all of the underlying values being addressed. |
124 | static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, |
125 | DenseSet<Value> &visited, |
126 | SmallVectorImpl<Value> &output) { |
127 | Operation *op = result.getOwner(); |
128 | |
129 | // If this is a view, unwrap to the source. |
130 | if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) |
131 | return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, |
132 | visited, output); |
133 | // Check to see if we can reason about the control flow of this op. |
134 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { |
135 | return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, |
136 | result.getResultNumber(), maxDepth, |
137 | visited, output); |
138 | } |
139 | |
140 | output.push_back(Elt: result); |
141 | } |
142 | |
143 | /// Given a block argument, collect all of the underlying values being |
144 | /// addressed. |
145 | static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, |
146 | DenseSet<Value> &visited, |
147 | SmallVectorImpl<Value> &output) { |
148 | Block *block = arg.getOwner(); |
149 | unsigned argNumber = arg.getArgNumber(); |
150 | |
151 | // Handle the case of a non-entry block. |
152 | if (!block->isEntryBlock()) { |
153 | for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { |
154 | auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator()); |
155 | if (!branch) { |
156 | // We can't analyze the control flow, so bail out early. |
157 | output.push_back(Elt: arg); |
158 | return; |
159 | } |
160 | |
161 | // Try to get the operand passed for this argument. |
162 | unsigned index = it.getSuccessorIndex(); |
163 | Value operand = branch.getSuccessorOperands(index)[argNumber]; |
164 | if (!operand) { |
165 | // We can't analyze the control flow, so bail out early. |
166 | output.push_back(Elt: arg); |
167 | return; |
168 | } |
169 | collectUnderlyingAddressValues(value: operand, maxDepth, visited, output); |
170 | } |
171 | return; |
172 | } |
173 | |
174 | // Otherwise, check to see if we can reason about the control flow of this op. |
175 | Region *region = block->getParent(); |
176 | Operation *op = region->getParentOp(); |
177 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { |
178 | return collectUnderlyingAddressValues(branch, region, arg, argNumber, |
179 | maxDepth, visited, output); |
180 | } |
181 | |
182 | // We can't reason about the underlying address of this argument. |
183 | output.push_back(Elt: arg); |
184 | } |
185 | |
186 | /// Given a value, collect all of the underlying values being addressed. |
187 | static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, |
188 | DenseSet<Value> &visited, |
189 | SmallVectorImpl<Value> &output) { |
190 | // Check that we don't infinitely recurse. |
191 | if (!visited.insert(V: value).second) |
192 | return; |
193 | if (maxDepth == 0) { |
194 | output.push_back(Elt: value); |
195 | return; |
196 | } |
197 | --maxDepth; |
198 | |
199 | if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) |
200 | return collectUnderlyingAddressValues(arg, maxDepth, visited, output); |
201 | collectUnderlyingAddressValues(result: cast<OpResult>(Val&: value), maxDepth, visited, |
202 | output); |
203 | } |
204 | |
205 | /// Given a value, collect all of the underlying values being addressed. |
206 | static void collectUnderlyingAddressValues(Value value, |
207 | SmallVectorImpl<Value> &output) { |
208 | DenseSet<Value> visited; |
209 | collectUnderlyingAddressValues(value, maxDepth: maxUnderlyingValueSearchDepth, visited, |
210 | output); |
211 | } |
212 | |
213 | //===----------------------------------------------------------------------===// |
214 | // LocalAliasAnalysis: alias |
215 | //===----------------------------------------------------------------------===// |
216 | |
217 | /// Given a value, try to get an allocation effect attached to it. If |
218 | /// successful, `allocEffect` is populated with the effect. If an effect was |
219 | /// found, `allocScopeOp` is also specified if a parent operation of `value` |
220 | /// could be identified that bounds the scope of the allocated value; i.e. if |
221 | /// non-null it specifies the parent operation that the allocation does not |
222 | /// escape. If no scope is found, `allocScopeOp` is set to nullptr. |
223 | static LogicalResult |
224 | getAllocEffectFor(Value value, |
225 | std::optional<MemoryEffects::EffectInstance> &effect, |
226 | Operation *&allocScopeOp) { |
227 | // Try to get a memory effect interface for the parent operation. |
228 | Operation *op; |
229 | if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) |
230 | op = arg.getOwner()->getParentOp(); |
231 | else |
232 | op = cast<OpResult>(Val&: value).getOwner(); |
233 | MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); |
234 | if (!interface) |
235 | return failure(); |
236 | |
237 | // Try to find an allocation effect on the resource. |
238 | if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) |
239 | return failure(); |
240 | |
241 | // If we found an allocation effect, try to find a scope for the allocation. |
242 | // If the resource of this allocation is automatically scoped, find the parent |
243 | // operation that bounds the allocation scope. |
244 | if (llvm::isa<SideEffects::AutomaticAllocationScopeResource>( |
245 | Val: effect->getResource())) { |
246 | allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); |
247 | return success(); |
248 | } |
249 | |
250 | // TODO: Here we could look at the users to see if the resource is either |
251 | // freed on all paths within the region, or is just not captured by anything. |
252 | // For now assume allocation scope to the function scope (we don't care if |
253 | // pointer escape outside function). |
254 | allocScopeOp = op->getParentOfType<FunctionOpInterface>(); |
255 | return success(); |
256 | } |
257 | |
258 | /// Given the two values, return their aliasing behavior. |
259 | AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { |
260 | if (lhs == rhs) |
261 | return AliasResult::MustAlias; |
262 | Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr; |
263 | std::optional<MemoryEffects::EffectInstance> lhsAlloc, rhsAlloc; |
264 | |
265 | // Handle the case where lhs is a constant. |
266 | Attribute lhsAttr, rhsAttr; |
267 | if (matchPattern(value: lhs, pattern: m_Constant(bind_value: &lhsAttr))) { |
268 | // TODO: This is overly conservative. Two matching constants don't |
269 | // necessarily map to the same address. For example, if the two values |
270 | // correspond to different symbols that both represent a definition. |
271 | if (matchPattern(value: rhs, pattern: m_Constant(bind_value: &rhsAttr))) |
272 | return AliasResult::MayAlias; |
273 | |
274 | // Try to find an alloc effect on rhs. If an effect was found we can't |
275 | // alias, otherwise we might. |
276 | return succeeded(Result: getAllocEffectFor(value: rhs, effect&: rhsAlloc, allocScopeOp&: rhsAllocScope)) |
277 | ? AliasResult::NoAlias |
278 | : AliasResult::MayAlias; |
279 | } |
280 | // Handle the case where rhs is a constant. |
281 | if (matchPattern(value: rhs, pattern: m_Constant(bind_value: &rhsAttr))) { |
282 | // Try to find an alloc effect on lhs. If an effect was found we can't |
283 | // alias, otherwise we might. |
284 | return succeeded(Result: getAllocEffectFor(value: lhs, effect&: lhsAlloc, allocScopeOp&: lhsAllocScope)) |
285 | ? AliasResult::NoAlias |
286 | : AliasResult::MayAlias; |
287 | } |
288 | |
289 | // Otherwise, neither of the values are constant so check to see if either has |
290 | // an allocation effect. |
291 | bool lhsHasAlloc = succeeded(Result: getAllocEffectFor(value: lhs, effect&: lhsAlloc, allocScopeOp&: lhsAllocScope)); |
292 | bool rhsHasAlloc = succeeded(Result: getAllocEffectFor(value: rhs, effect&: rhsAlloc, allocScopeOp&: rhsAllocScope)); |
293 | if (lhsHasAlloc == rhsHasAlloc) { |
294 | // If both values have an allocation effect we know they don't alias, and if |
295 | // neither have an effect we can't make an assumptions. |
296 | return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; |
297 | } |
298 | |
299 | // When we reach this point we have one value with a known allocation effect, |
300 | // and one without. Move the one with the effect to the lhs to make the next |
301 | // checks simpler. |
302 | if (rhsHasAlloc) { |
303 | std::swap(a&: lhs, b&: rhs); |
304 | lhsAlloc = rhsAlloc; |
305 | lhsAllocScope = rhsAllocScope; |
306 | } |
307 | |
308 | // If the effect has a scoped allocation region, check to see if the |
309 | // non-effect value is defined above that scope. |
310 | if (lhsAllocScope) { |
311 | // If the parent operation of rhs is an ancestor of the allocation scope, or |
312 | // if rhs is an entry block argument of the allocation scope we know the two |
313 | // values can't alias. |
314 | Operation *rhsParentOp = rhs.getParentRegion()->getParentOp(); |
315 | if (rhsParentOp->isProperAncestor(other: lhsAllocScope)) |
316 | return AliasResult::NoAlias; |
317 | if (rhsParentOp == lhsAllocScope) { |
318 | BlockArgument rhsArg = dyn_cast<BlockArgument>(Val&: rhs); |
319 | if (rhsArg && rhs.getParentBlock()->isEntryBlock()) |
320 | return AliasResult::NoAlias; |
321 | } |
322 | } |
323 | |
324 | // If we couldn't reason about the relationship between the two values, |
325 | // conservatively assume they might alias. |
326 | return AliasResult::MayAlias; |
327 | } |
328 | |
329 | /// Given the two values, return their aliasing behavior. |
330 | AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { |
331 | if (lhs == rhs) |
332 | return AliasResult::MustAlias; |
333 | |
334 | // Get the underlying values being addressed. |
335 | SmallVector<Value, 8> lhsValues, rhsValues; |
336 | collectUnderlyingAddressValues(value: lhs, output&: lhsValues); |
337 | collectUnderlyingAddressValues(value: rhs, output&: rhsValues); |
338 | |
339 | // If we failed to collect for either of the values somehow, conservatively |
340 | // assume they may alias. |
341 | if (lhsValues.empty() || rhsValues.empty()) |
342 | return AliasResult::MayAlias; |
343 | |
344 | // Check the alias results against each of the underlying values. |
345 | std::optional<AliasResult> result; |
346 | for (Value lhsVal : lhsValues) { |
347 | for (Value rhsVal : rhsValues) { |
348 | AliasResult nextResult = aliasImpl(lhs: lhsVal, rhs: rhsVal); |
349 | result = result ? result->merge(other: nextResult) : nextResult; |
350 | } |
351 | } |
352 | |
353 | // We should always have a valid result here. |
354 | return *result; |
355 | } |
356 | |
357 | //===----------------------------------------------------------------------===// |
358 | // LocalAliasAnalysis: getModRef |
359 | //===----------------------------------------------------------------------===// |
360 | |
361 | ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { |
362 | // Check to see if this operation relies on nested side effects. |
363 | if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) { |
364 | // TODO: To check recursive operations we need to check all of the nested |
365 | // operations, which can result in a quadratic number of queries. We should |
366 | // introduce some caching of some kind to help alleviate this, especially as |
367 | // this caching could be used in other areas of the codebase (e.g. when |
368 | // checking `wouldOpBeTriviallyDead`). |
369 | return ModRefResult::getModAndRef(); |
370 | } |
371 | |
372 | // Otherwise, check to see if this operation has a memory effect interface. |
373 | MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); |
374 | if (!interface) |
375 | return ModRefResult::getModAndRef(); |
376 | |
377 | // Build a ModRefResult by merging the behavior of the effects of this |
378 | // operation. |
379 | SmallVector<MemoryEffects::EffectInstance> effects; |
380 | interface.getEffects(effects); |
381 | |
382 | ModRefResult result = ModRefResult::getNoModRef(); |
383 | for (const MemoryEffects::EffectInstance &effect : effects) { |
384 | if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(Val: effect.getEffect())) |
385 | continue; |
386 | |
387 | // Check for an alias between the effect and our memory location. |
388 | // TODO: Add support for checking an alias with a symbol reference. |
389 | AliasResult aliasResult = AliasResult::MayAlias; |
390 | if (Value effectValue = effect.getValue()) |
391 | aliasResult = alias(lhs: effectValue, rhs: location); |
392 | |
393 | // If we don't alias, ignore this effect. |
394 | if (aliasResult.isNo()) |
395 | continue; |
396 | |
397 | // Merge in the corresponding mod or ref for this effect. |
398 | if (isa<MemoryEffects::Read>(Val: effect.getEffect())) { |
399 | result = result.merge(other: ModRefResult::getRef()); |
400 | } else { |
401 | assert(isa<MemoryEffects::Write>(effect.getEffect())); |
402 | result = result.merge(other: ModRefResult::getMod()); |
403 | } |
404 | if (result.isModAndRef()) |
405 | break; |
406 | } |
407 | return result; |
408 | } |
409 | |