1//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the dataflow analysis class for integer range inference
10// which is used in transformations over the `arith` dialect such as
11// branch elimination or signed->unsigned rewriting
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18#include "mlir/Analysis/DataFlowFramework.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Dialect.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/IR/TypeUtilities.h"
23#include "mlir/IR/Value.h"
24#include "mlir/Interfaces/ControlFlowInterfaces.h"
25#include "mlir/Interfaces/InferIntRangeInterface.h"
26#include "mlir/Interfaces/LoopLikeInterface.h"
27#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/Casting.h"
30#include "llvm/Support/Debug.h"
31#include <cassert>
32#include <optional>
33#include <utility>
34
35#define DEBUG_TYPE "int-range-analysis"
36
37using namespace mlir;
38using namespace mlir::dataflow;
39
40namespace mlir::dataflow {
41LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
42 auto *result = solver.lookupState<IntegerValueRangeLattice>(anchor: v);
43 if (!result || result->getValue().isUninitialized())
44 return failure();
45 const ConstantIntRanges &range = result->getValue().getValue();
46 return success(IsSuccess: range.smin().isNonNegative());
47}
48
49LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
50 auto nonNegativePred = [&solver](Value v) -> bool {
51 return succeeded(Result: staticallyNonNegative(solver, v));
52 };
53 return success(IsSuccess: llvm::all_of(Range: op->getOperands(), P: nonNegativePred) &&
54 llvm::all_of(Range: op->getResults(), P: nonNegativePred));
55}
56} // namespace mlir::dataflow
57
58void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
59 Lattice::onUpdate(solver);
60
61 // If the integer range can be narrowed to a constant, update the constant
62 // value of the SSA value.
63 std::optional<APInt> constant = getValue().getValue().getConstantValue();
64 auto value = cast<Value>(Val: anchor);
65 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(anchor: value);
66 if (!constant)
67 return solver->propagateIfChanged(
68 state: cv, changed: cv->join(rhs: ConstantValue::getUnknownConstant()));
69
70 Dialect *dialect;
71 if (auto *parent = value.getDefiningOp())
72 dialect = parent->getDialect();
73 else
74 dialect = value.getParentBlock()->getParentOp()->getDialect();
75
76 Type type = getElementTypeOrSelf(val: value);
77 solver->propagateIfChanged(
78 cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
79}
80
81LogicalResult IntegerRangeAnalysis::visitOperation(
82 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
83 ArrayRef<IntegerValueRangeLattice *> results) {
84 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
85 if (!inferrable) {
86 setAllToEntryStates(results);
87 return success();
88 }
89
90 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
91 auto argRanges = llvm::map_to_vector(
92 C&: operands, F: [](const IntegerValueRangeLattice *lattice) {
93 return lattice->getValue();
94 });
95
96 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
97 auto result = dyn_cast<OpResult>(Val&: v);
98 if (!result)
99 return;
100 assert(llvm::is_contained(op->getResults(), result));
101
102 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
103 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
104 IntegerValueRange oldRange = lattice->getValue();
105
106 ChangeResult changed = lattice->join(rhs: attrs);
107
108 // Catch loop results with loop variant bounds and conservatively make
109 // them [-inf, inf] so we don't circle around infinitely often (because
110 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
111 // and often can't).
112 bool isYieldedResult = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) {
113 return op->hasTrait<OpTrait::IsTerminator>();
114 });
115 if (isYieldedResult && !oldRange.isUninitialized() &&
116 !(lattice->getValue() == oldRange)) {
117 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
118 changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v));
119 }
120 propagateIfChanged(lattice, changed);
121 };
122
123 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
124 return success();
125}
126
127void IntegerRangeAnalysis::visitNonControlFlowArguments(
128 Operation *op, const RegionSuccessor &successor,
129 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
130 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
131 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
132
133 auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
134 return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
135 });
136
137 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
138 auto arg = dyn_cast<BlockArgument>(Val&: v);
139 if (!arg)
140 return;
141 if (!llvm::is_contained(Range: successor.getSuccessor()->getArguments(), Element: arg))
142 return;
143
144 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
145 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
146 IntegerValueRange oldRange = lattice->getValue();
147
148 ChangeResult changed = lattice->join(rhs: attrs);
149
150 // Catch loop results with loop variant bounds and conservatively make
151 // them [-inf, inf] so we don't circle around infinitely often (because
152 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
153 // and often can't).
154 bool isYieldedValue = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) {
155 return op->hasTrait<OpTrait::IsTerminator>();
156 });
157 if (isYieldedValue && !oldRange.isUninitialized() &&
158 !(lattice->getValue() == oldRange)) {
159 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
160 changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v));
161 }
162 propagateIfChanged(lattice, changed);
163 };
164
165 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
166 return;
167 }
168
169 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
170 /// on a LoopLikeInterface return the lower/upper bound for that result if
171 /// possible.
172 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
173 Type boundType, Block *block, bool getUpper) {
174 unsigned int width = ConstantIntRanges::getStorageBitwidth(type: boundType);
175 if (loopBound.has_value()) {
176 if (auto attr = dyn_cast<Attribute>(Val&: *loopBound)) {
177 if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
178 return bound.getValue();
179 } else if (auto value = llvm::dyn_cast_if_present<Value>(Val&: *loopBound)) {
180 const IntegerValueRangeLattice *lattice =
181 getLatticeElementFor(getProgramPointBefore(block), value);
182 if (lattice != nullptr && !lattice->getValue().isUninitialized())
183 return getUpper ? lattice->getValue().getValue().smax()
184 : lattice->getValue().getValue().smin();
185 }
186 }
187 // Given the results of getConstant{Lower,Upper}Bound()
188 // or getConstantStep() on a LoopLikeInterface return the lower/upper
189 // bound
190 return getUpper ? APInt::getSignedMaxValue(width)
191 : APInt::getSignedMinValue(width);
192 };
193
194 // Infer bounds for loop arguments that have static bounds
195 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
196 std::optional<Value> iv = loop.getSingleInductionVar();
197 if (!iv) {
198 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
199 op, successor, argLattices, firstIndex);
200 }
201 Block *block = iv->getParentBlock();
202 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
203 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
204 std::optional<OpFoldResult> step = loop.getSingleStep();
205 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
206 /*getUpper=*/false);
207 APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
208 /*getUpper=*/true);
209 // Assume positivity for uniscoverable steps by way of getUpper = true.
210 APInt stepVal =
211 getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
212
213 if (stepVal.isNegative()) {
214 std::swap(a&: min, b&: max);
215 } else {
216 // Correct the upper bound by subtracting 1 so that it becomes a <=
217 // bound, because loops do not generally include their upper bound.
218 max -= 1;
219 }
220
221 // If we infer the lower bound to be larger than the upper bound, the
222 // resulting range is meaningless and should not be used in further
223 // inferences.
224 if (max.sge(RHS: min)) {
225 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
226 auto ivRange = ConstantIntRanges::fromSigned(smin: min, smax: max);
227 propagateIfChanged(ivEntry, ivEntry->join(rhs: IntegerValueRange{ivRange}));
228 }
229 return;
230 }
231
232 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
233 op, successor, argLattices, firstIndex);
234}
235

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp