1 | //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the dataflow analysis class for integer range inference |
10 | // which is used in transformations over the `arith` dialect such as |
11 | // branch elimination or signed->unsigned rewriting |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
18 | #include "mlir/Analysis/DataFlowFramework.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Dialect.h" |
21 | #include "mlir/IR/OpDefinition.h" |
22 | #include "mlir/IR/Value.h" |
23 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
24 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
25 | #include "mlir/Interfaces/LoopLikeInterface.h" |
26 | #include "mlir/Support/LLVM.h" |
27 | #include "llvm/ADT/STLExtras.h" |
28 | #include "llvm/Support/Casting.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include <cassert> |
31 | #include <optional> |
32 | #include <utility> |
33 | |
34 | #define DEBUG_TYPE "int-range-analysis" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::dataflow; |
38 | |
39 | IntegerValueRange IntegerValueRange::getMaxRange(Value value) { |
40 | unsigned width = ConstantIntRanges::getStorageBitwidth(type: value.getType()); |
41 | if (width == 0) |
42 | return {}; |
43 | APInt umin = APInt::getMinValue(numBits: width); |
44 | APInt umax = APInt::getMaxValue(numBits: width); |
45 | APInt smin = width != 0 ? APInt::getSignedMinValue(numBits: width) : umin; |
46 | APInt smax = width != 0 ? APInt::getSignedMaxValue(numBits: width) : umax; |
47 | return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}}; |
48 | } |
49 | |
50 | void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { |
51 | Lattice::onUpdate(solver); |
52 | |
53 | // If the integer range can be narrowed to a constant, update the constant |
54 | // value of the SSA value. |
55 | std::optional<APInt> constant = getValue().getValue().getConstantValue(); |
56 | auto value = point.get<Value>(); |
57 | auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); |
58 | if (!constant) |
59 | return solver->propagateIfChanged( |
60 | state: cv, changed: cv->join(ConstantValue::getUnknownConstant())); |
61 | |
62 | Dialect *dialect; |
63 | if (auto *parent = value.getDefiningOp()) |
64 | dialect = parent->getDialect(); |
65 | else |
66 | dialect = value.getParentBlock()->getParentOp()->getDialect(); |
67 | solver->propagateIfChanged( |
68 | cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant), |
69 | dialect))); |
70 | } |
71 | |
72 | void IntegerRangeAnalysis::visitOperation( |
73 | Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, |
74 | ArrayRef<IntegerValueRangeLattice *> results) { |
75 | // If the lattice on any operand is unitialized, bail out. |
76 | if (llvm::any_of(Range&: operands, P: [](const IntegerValueRangeLattice *lattice) { |
77 | return lattice->getValue().isUninitialized(); |
78 | })) { |
79 | return; |
80 | } |
81 | |
82 | auto inferrable = dyn_cast<InferIntRangeInterface>(op); |
83 | if (!inferrable) |
84 | return setAllToEntryStates(results); |
85 | |
86 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
87 | SmallVector<ConstantIntRanges> argRanges( |
88 | llvm::map_range(operands, [](const IntegerValueRangeLattice *val) { |
89 | return val->getValue().getValue(); |
90 | })); |
91 | |
92 | auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { |
93 | auto result = dyn_cast<OpResult>(Val&: v); |
94 | if (!result) |
95 | return; |
96 | assert(llvm::is_contained(op->getResults(), result)); |
97 | |
98 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
99 | IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; |
100 | IntegerValueRange oldRange = lattice->getValue(); |
101 | |
102 | ChangeResult changed = lattice->join(IntegerValueRange{attrs}); |
103 | |
104 | // Catch loop results with loop variant bounds and conservatively make |
105 | // them [-inf, inf] so we don't circle around infinitely often (because |
106 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
107 | // and often can't). |
108 | bool isYieldedResult = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
109 | return op->hasTrait<OpTrait::IsTerminator>(); |
110 | }); |
111 | if (isYieldedResult && !oldRange.isUninitialized() && |
112 | !(lattice->getValue() == oldRange)) { |
113 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
114 | changed |= lattice->join(IntegerValueRange::getMaxRange(value: v)); |
115 | } |
116 | propagateIfChanged(lattice, changed); |
117 | }; |
118 | |
119 | inferrable.inferResultRanges(argRanges, joinCallback); |
120 | } |
121 | |
122 | void IntegerRangeAnalysis::visitNonControlFlowArguments( |
123 | Operation *op, const RegionSuccessor &successor, |
124 | ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { |
125 | if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { |
126 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
127 | // If the lattice on any operand is unitialized, bail out. |
128 | if (llvm::any_of(Range: op->getOperands(), P: [&](Value value) { |
129 | return getLatticeElementFor(op, value)->getValue().isUninitialized(); |
130 | })) |
131 | return; |
132 | SmallVector<ConstantIntRanges> argRanges( |
133 | llvm::map_range(op->getOperands(), [&](Value value) { |
134 | return getLatticeElementFor(op, value)->getValue().getValue(); |
135 | })); |
136 | |
137 | auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { |
138 | auto arg = dyn_cast<BlockArgument>(Val&: v); |
139 | if (!arg) |
140 | return; |
141 | if (!llvm::is_contained(Range: successor.getSuccessor()->getArguments(), Element: arg)) |
142 | return; |
143 | |
144 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
145 | IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; |
146 | IntegerValueRange oldRange = lattice->getValue(); |
147 | |
148 | ChangeResult changed = lattice->join(IntegerValueRange{attrs}); |
149 | |
150 | // Catch loop results with loop variant bounds and conservatively make |
151 | // them [-inf, inf] so we don't circle around infinitely often (because |
152 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
153 | // and often can't). |
154 | bool isYieldedValue = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
155 | return op->hasTrait<OpTrait::IsTerminator>(); |
156 | }); |
157 | if (isYieldedValue && !oldRange.isUninitialized() && |
158 | !(lattice->getValue() == oldRange)) { |
159 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
160 | changed |= lattice->join(IntegerValueRange::getMaxRange(value: v)); |
161 | } |
162 | propagateIfChanged(lattice, changed); |
163 | }; |
164 | |
165 | inferrable.inferResultRanges(argRanges, joinCallback); |
166 | return; |
167 | } |
168 | |
169 | /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() |
170 | /// on a LoopLikeInterface return the lower/upper bound for that result if |
171 | /// possible. |
172 | auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, |
173 | Type boundType, bool getUpper) { |
174 | unsigned int width = ConstantIntRanges::getStorageBitwidth(type: boundType); |
175 | if (loopBound.has_value()) { |
176 | if (loopBound->is<Attribute>()) { |
177 | if (auto bound = |
178 | dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>())) |
179 | return bound.getValue(); |
180 | } else if (auto value = llvm::dyn_cast_if_present<Value>(Val&: *loopBound)) { |
181 | const IntegerValueRangeLattice *lattice = |
182 | getLatticeElementFor(op, value); |
183 | if (lattice != nullptr && !lattice->getValue().isUninitialized()) |
184 | return getUpper ? lattice->getValue().getValue().smax() |
185 | : lattice->getValue().getValue().smin(); |
186 | } |
187 | } |
188 | // Given the results of getConstant{Lower,Upper}Bound() |
189 | // or getConstantStep() on a LoopLikeInterface return the lower/upper |
190 | // bound |
191 | return getUpper ? APInt::getSignedMaxValue(width) |
192 | : APInt::getSignedMinValue(width); |
193 | }; |
194 | |
195 | // Infer bounds for loop arguments that have static bounds |
196 | if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { |
197 | std::optional<Value> iv = loop.getSingleInductionVar(); |
198 | if (!iv) { |
199 | return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( |
200 | op, successor, argLattices, firstIndex); |
201 | } |
202 | std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); |
203 | std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); |
204 | std::optional<OpFoldResult> step = loop.getSingleStep(); |
205 | APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), |
206 | /*getUpper=*/false); |
207 | APInt max = getLoopBoundFromFold(upperBound, iv->getType(), |
208 | /*getUpper=*/true); |
209 | // Assume positivity for uniscoverable steps by way of getUpper = true. |
210 | APInt stepVal = |
211 | getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); |
212 | |
213 | if (stepVal.isNegative()) { |
214 | std::swap(a&: min, b&: max); |
215 | } else { |
216 | // Correct the upper bound by subtracting 1 so that it becomes a <= |
217 | // bound, because loops do not generally include their upper bound. |
218 | max -= 1; |
219 | } |
220 | |
221 | IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); |
222 | auto ivRange = ConstantIntRanges::fromSigned(smin: min, smax: max); |
223 | propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); |
224 | return; |
225 | } |
226 | |
227 | return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( |
228 | op, successor, argLattices, firstIndex); |
229 | } |
230 | |