//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the dataflow analysis class for integer range inference
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting.
//
//===----------------------------------------------------------------------===//
14
15#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
18#include "mlir/Analysis/DataFlowFramework.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Dialect.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/IR/Value.h"
23#include "mlir/Interfaces/ControlFlowInterfaces.h"
24#include "mlir/Interfaces/InferIntRangeInterface.h"
25#include "mlir/Interfaces/LoopLikeInterface.h"
26#include "mlir/Support/LLVM.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/Debug.h"
30#include <cassert>
31#include <optional>
32#include <utility>
33
34#define DEBUG_TYPE "int-range-analysis"
35
36using namespace mlir;
37using namespace mlir::dataflow;
38
39IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
40 unsigned width = ConstantIntRanges::getStorageBitwidth(type: value.getType());
41 if (width == 0)
42 return {};
43 APInt umin = APInt::getMinValue(numBits: width);
44 APInt umax = APInt::getMaxValue(numBits: width);
45 APInt smin = width != 0 ? APInt::getSignedMinValue(numBits: width) : umin;
46 APInt smax = width != 0 ? APInt::getSignedMaxValue(numBits: width) : umax;
47 return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
48}
49
50void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
51 Lattice::onUpdate(solver);
52
53 // If the integer range can be narrowed to a constant, update the constant
54 // value of the SSA value.
55 std::optional<APInt> constant = getValue().getValue().getConstantValue();
56 auto value = point.get<Value>();
57 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
58 if (!constant)
59 return solver->propagateIfChanged(
60 state: cv, changed: cv->join(ConstantValue::getUnknownConstant()));
61
62 Dialect *dialect;
63 if (auto *parent = value.getDefiningOp())
64 dialect = parent->getDialect();
65 else
66 dialect = value.getParentBlock()->getParentOp()->getDialect();
67 solver->propagateIfChanged(
68 cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
69 dialect)));
70}
71
72void IntegerRangeAnalysis::visitOperation(
73 Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
74 ArrayRef<IntegerValueRangeLattice *> results) {
75 // If the lattice on any operand is unitialized, bail out.
76 if (llvm::any_of(Range&: operands, P: [](const IntegerValueRangeLattice *lattice) {
77 return lattice->getValue().isUninitialized();
78 })) {
79 return;
80 }
81
82 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
83 if (!inferrable)
84 return setAllToEntryStates(results);
85
86 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
87 SmallVector<ConstantIntRanges> argRanges(
88 llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
89 return val->getValue().getValue();
90 }));
91
92 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
93 auto result = dyn_cast<OpResult>(Val&: v);
94 if (!result)
95 return;
96 assert(llvm::is_contained(op->getResults(), result));
97
98 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
99 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
100 IntegerValueRange oldRange = lattice->getValue();
101
102 ChangeResult changed = lattice->join(IntegerValueRange{attrs});
103
104 // Catch loop results with loop variant bounds and conservatively make
105 // them [-inf, inf] so we don't circle around infinitely often (because
106 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
107 // and often can't).
108 bool isYieldedResult = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) {
109 return op->hasTrait<OpTrait::IsTerminator>();
110 });
111 if (isYieldedResult && !oldRange.isUninitialized() &&
112 !(lattice->getValue() == oldRange)) {
113 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
114 changed |= lattice->join(IntegerValueRange::getMaxRange(value: v));
115 }
116 propagateIfChanged(lattice, changed);
117 };
118
119 inferrable.inferResultRanges(argRanges, joinCallback);
120}
121
122void IntegerRangeAnalysis::visitNonControlFlowArguments(
123 Operation *op, const RegionSuccessor &successor,
124 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
125 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
126 LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
127 // If the lattice on any operand is unitialized, bail out.
128 if (llvm::any_of(Range: op->getOperands(), P: [&](Value value) {
129 return getLatticeElementFor(op, value)->getValue().isUninitialized();
130 }))
131 return;
132 SmallVector<ConstantIntRanges> argRanges(
133 llvm::map_range(op->getOperands(), [&](Value value) {
134 return getLatticeElementFor(op, value)->getValue().getValue();
135 }));
136
137 auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
138 auto arg = dyn_cast<BlockArgument>(Val&: v);
139 if (!arg)
140 return;
141 if (!llvm::is_contained(Range: successor.getSuccessor()->getArguments(), Element: arg))
142 return;
143
144 LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
145 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
146 IntegerValueRange oldRange = lattice->getValue();
147
148 ChangeResult changed = lattice->join(IntegerValueRange{attrs});
149
150 // Catch loop results with loop variant bounds and conservatively make
151 // them [-inf, inf] so we don't circle around infinitely often (because
152 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
153 // and often can't).
154 bool isYieldedValue = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) {
155 return op->hasTrait<OpTrait::IsTerminator>();
156 });
157 if (isYieldedValue && !oldRange.isUninitialized() &&
158 !(lattice->getValue() == oldRange)) {
159 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
160 changed |= lattice->join(IntegerValueRange::getMaxRange(value: v));
161 }
162 propagateIfChanged(lattice, changed);
163 };
164
165 inferrable.inferResultRanges(argRanges, joinCallback);
166 return;
167 }
168
169 /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
170 /// on a LoopLikeInterface return the lower/upper bound for that result if
171 /// possible.
172 auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
173 Type boundType, bool getUpper) {
174 unsigned int width = ConstantIntRanges::getStorageBitwidth(type: boundType);
175 if (loopBound.has_value()) {
176 if (loopBound->is<Attribute>()) {
177 if (auto bound =
178 dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
179 return bound.getValue();
180 } else if (auto value = llvm::dyn_cast_if_present<Value>(Val&: *loopBound)) {
181 const IntegerValueRangeLattice *lattice =
182 getLatticeElementFor(op, value);
183 if (lattice != nullptr && !lattice->getValue().isUninitialized())
184 return getUpper ? lattice->getValue().getValue().smax()
185 : lattice->getValue().getValue().smin();
186 }
187 }
188 // Given the results of getConstant{Lower,Upper}Bound()
189 // or getConstantStep() on a LoopLikeInterface return the lower/upper
190 // bound
191 return getUpper ? APInt::getSignedMaxValue(width)
192 : APInt::getSignedMinValue(width);
193 };
194
195 // Infer bounds for loop arguments that have static bounds
196 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
197 std::optional<Value> iv = loop.getSingleInductionVar();
198 if (!iv) {
199 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
200 op, successor, argLattices, firstIndex);
201 }
202 std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
203 std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
204 std::optional<OpFoldResult> step = loop.getSingleStep();
205 APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
206 /*getUpper=*/false);
207 APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
208 /*getUpper=*/true);
209 // Assume positivity for uniscoverable steps by way of getUpper = true.
210 APInt stepVal =
211 getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
212
213 if (stepVal.isNegative()) {
214 std::swap(a&: min, b&: max);
215 } else {
216 // Correct the upper bound by subtracting 1 so that it becomes a <=
217 // bound, because loops do not generally include their upper bound.
218 max -= 1;
219 }
220
221 IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
222 auto ivRange = ConstantIntRanges::fromSigned(smin: min, smax: max);
223 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
224 return;
225 }
226
227 return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
228 op, successor, argLattices, firstIndex);
229}
230

source code of mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp