1 | //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the dataflow analysis class for integer range inference |
10 | // which is used in transformations over the `arith` dialect such as |
11 | // branch elimination or signed->unsigned rewriting |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
18 | #include "mlir/Analysis/DataFlowFramework.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Dialect.h" |
21 | #include "mlir/IR/OpDefinition.h" |
22 | #include "mlir/IR/TypeUtilities.h" |
23 | #include "mlir/IR/Value.h" |
24 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
25 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
26 | #include "mlir/Interfaces/LoopLikeInterface.h" |
27 | #include "mlir/Support/LLVM.h" |
28 | #include "llvm/ADT/STLExtras.h" |
29 | #include "llvm/Support/Casting.h" |
30 | #include "llvm/Support/Debug.h" |
31 | #include <cassert> |
32 | #include <optional> |
33 | #include <utility> |
34 | |
35 | #define DEBUG_TYPE "int-range-analysis" |
36 | |
37 | using namespace mlir; |
38 | using namespace mlir::dataflow; |
39 | |
40 | namespace mlir::dataflow { |
41 | LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { |
42 | auto *result = solver.lookupState<IntegerValueRangeLattice>(anchor: v); |
43 | if (!result || result->getValue().isUninitialized()) |
44 | return failure(); |
45 | const ConstantIntRanges &range = result->getValue().getValue(); |
46 | return success(IsSuccess: range.smin().isNonNegative()); |
47 | } |
48 | |
49 | LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { |
50 | auto nonNegativePred = [&solver](Value v) -> bool { |
51 | return succeeded(Result: staticallyNonNegative(solver, v)); |
52 | }; |
53 | return success(IsSuccess: llvm::all_of(Range: op->getOperands(), P: nonNegativePred) && |
54 | llvm::all_of(Range: op->getResults(), P: nonNegativePred)); |
55 | } |
56 | } // namespace mlir::dataflow |
57 | |
58 | void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { |
59 | Lattice::onUpdate(solver); |
60 | |
61 | // If the integer range can be narrowed to a constant, update the constant |
62 | // value of the SSA value. |
63 | std::optional<APInt> constant = getValue().getValue().getConstantValue(); |
64 | auto value = cast<Value>(Val: anchor); |
65 | auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(anchor: value); |
66 | if (!constant) |
67 | return solver->propagateIfChanged( |
68 | state: cv, changed: cv->join(rhs: ConstantValue::getUnknownConstant())); |
69 | |
70 | Dialect *dialect; |
71 | if (auto *parent = value.getDefiningOp()) |
72 | dialect = parent->getDialect(); |
73 | else |
74 | dialect = value.getParentBlock()->getParentOp()->getDialect(); |
75 | |
76 | Type type = getElementTypeOrSelf(val: value); |
77 | solver->propagateIfChanged( |
78 | cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); |
79 | } |
80 | |
81 | LogicalResult IntegerRangeAnalysis::visitOperation( |
82 | Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, |
83 | ArrayRef<IntegerValueRangeLattice *> results) { |
84 | auto inferrable = dyn_cast<InferIntRangeInterface>(op); |
85 | if (!inferrable) { |
86 | setAllToEntryStates(results); |
87 | return success(); |
88 | } |
89 | |
90 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
91 | auto argRanges = llvm::map_to_vector( |
92 | C&: operands, F: [](const IntegerValueRangeLattice *lattice) { |
93 | return lattice->getValue(); |
94 | }); |
95 | |
96 | auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
97 | auto result = dyn_cast<OpResult>(Val&: v); |
98 | if (!result) |
99 | return; |
100 | assert(llvm::is_contained(op->getResults(), result)); |
101 | |
102 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
103 | IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; |
104 | IntegerValueRange oldRange = lattice->getValue(); |
105 | |
106 | ChangeResult changed = lattice->join(rhs: attrs); |
107 | |
108 | // Catch loop results with loop variant bounds and conservatively make |
109 | // them [-inf, inf] so we don't circle around infinitely often (because |
110 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
111 | // and often can't). |
112 | bool isYieldedResult = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
113 | return op->hasTrait<OpTrait::IsTerminator>(); |
114 | }); |
115 | if (isYieldedResult && !oldRange.isUninitialized() && |
116 | !(lattice->getValue() == oldRange)) { |
117 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
118 | changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v)); |
119 | } |
120 | propagateIfChanged(lattice, changed); |
121 | }; |
122 | |
123 | inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
124 | return success(); |
125 | } |
126 | |
127 | void IntegerRangeAnalysis::visitNonControlFlowArguments( |
128 | Operation *op, const RegionSuccessor &successor, |
129 | ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { |
130 | if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { |
131 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
132 | |
133 | auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { |
134 | return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); |
135 | }); |
136 | |
137 | auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
138 | auto arg = dyn_cast<BlockArgument>(Val&: v); |
139 | if (!arg) |
140 | return; |
141 | if (!llvm::is_contained(Range: successor.getSuccessor()->getArguments(), Element: arg)) |
142 | return; |
143 | |
144 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
145 | IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; |
146 | IntegerValueRange oldRange = lattice->getValue(); |
147 | |
148 | ChangeResult changed = lattice->join(rhs: attrs); |
149 | |
150 | // Catch loop results with loop variant bounds and conservatively make |
151 | // them [-inf, inf] so we don't circle around infinitely often (because |
152 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
153 | // and often can't). |
154 | bool isYieldedValue = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
155 | return op->hasTrait<OpTrait::IsTerminator>(); |
156 | }); |
157 | if (isYieldedValue && !oldRange.isUninitialized() && |
158 | !(lattice->getValue() == oldRange)) { |
159 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
160 | changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v)); |
161 | } |
162 | propagateIfChanged(lattice, changed); |
163 | }; |
164 | |
165 | inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
166 | return; |
167 | } |
168 | |
169 | /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() |
170 | /// on a LoopLikeInterface return the lower/upper bound for that result if |
171 | /// possible. |
172 | auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, |
173 | Type boundType, Block *block, bool getUpper) { |
174 | unsigned int width = ConstantIntRanges::getStorageBitwidth(type: boundType); |
175 | if (loopBound.has_value()) { |
176 | if (auto attr = dyn_cast<Attribute>(Val&: *loopBound)) { |
177 | if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) |
178 | return bound.getValue(); |
179 | } else if (auto value = llvm::dyn_cast_if_present<Value>(Val&: *loopBound)) { |
180 | const IntegerValueRangeLattice *lattice = |
181 | getLatticeElementFor(getProgramPointBefore(block), value); |
182 | if (lattice != nullptr && !lattice->getValue().isUninitialized()) |
183 | return getUpper ? lattice->getValue().getValue().smax() |
184 | : lattice->getValue().getValue().smin(); |
185 | } |
186 | } |
187 | // Given the results of getConstant{Lower,Upper}Bound() |
188 | // or getConstantStep() on a LoopLikeInterface return the lower/upper |
189 | // bound |
190 | return getUpper ? APInt::getSignedMaxValue(width) |
191 | : APInt::getSignedMinValue(width); |
192 | }; |
193 | |
194 | // Infer bounds for loop arguments that have static bounds |
195 | if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { |
196 | std::optional<Value> iv = loop.getSingleInductionVar(); |
197 | if (!iv) { |
198 | return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( |
199 | op, successor, argLattices, firstIndex); |
200 | } |
201 | Block *block = iv->getParentBlock(); |
202 | std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); |
203 | std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); |
204 | std::optional<OpFoldResult> step = loop.getSingleStep(); |
205 | APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, |
206 | /*getUpper=*/false); |
207 | APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, |
208 | /*getUpper=*/true); |
209 | // Assume positivity for uniscoverable steps by way of getUpper = true. |
210 | APInt stepVal = |
211 | getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); |
212 | |
213 | if (stepVal.isNegative()) { |
214 | std::swap(a&: min, b&: max); |
215 | } else { |
216 | // Correct the upper bound by subtracting 1 so that it becomes a <= |
217 | // bound, because loops do not generally include their upper bound. |
218 | max -= 1; |
219 | } |
220 | |
221 | // If we infer the lower bound to be larger than the upper bound, the |
222 | // resulting range is meaningless and should not be used in further |
223 | // inferences. |
224 | if (max.sge(RHS: min)) { |
225 | IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); |
226 | auto ivRange = ConstantIntRanges::fromSigned(smin: min, smax: max); |
227 | propagateIfChanged(ivEntry, ivEntry->join(rhs: IntegerValueRange{ivRange})); |
228 | } |
229 | return; |
230 | } |
231 | |
232 | return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( |
233 | op, successor, argLattices, firstIndex); |
234 | } |
235 | |