| 1 | //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the dataflow analysis class for integer range inference |
| 10 | // which is used in transformations over the `arith` dialect such as |
| 11 | // branch elimination or signed->unsigned rewriting |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
| 16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| 17 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| 18 | #include "mlir/Analysis/DataFlowFramework.h" |
| 19 | #include "mlir/IR/BuiltinAttributes.h" |
| 20 | #include "mlir/IR/Dialect.h" |
| 21 | #include "mlir/IR/OpDefinition.h" |
| 22 | #include "mlir/IR/TypeUtilities.h" |
| 23 | #include "mlir/IR/Value.h" |
| 24 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 25 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
| 26 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 27 | #include "mlir/Support/LLVM.h" |
| 28 | #include "llvm/ADT/STLExtras.h" |
| 29 | #include "llvm/Support/Casting.h" |
| 30 | #include "llvm/Support/Debug.h" |
| 31 | #include <cassert> |
| 32 | #include <optional> |
| 33 | #include <utility> |
| 34 | |
| 35 | #define DEBUG_TYPE "int-range-analysis" |
| 36 | |
| 37 | using namespace mlir; |
| 38 | using namespace mlir::dataflow; |
| 39 | |
| 40 | namespace mlir::dataflow { |
| 41 | LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { |
| 42 | auto *result = solver.lookupState<IntegerValueRangeLattice>(anchor: v); |
| 43 | if (!result || result->getValue().isUninitialized()) |
| 44 | return failure(); |
| 45 | const ConstantIntRanges &range = result->getValue().getValue(); |
| 46 | return success(IsSuccess: range.smin().isNonNegative()); |
| 47 | } |
| 48 | |
| 49 | LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { |
| 50 | auto nonNegativePred = [&solver](Value v) -> bool { |
| 51 | return succeeded(Result: staticallyNonNegative(solver, v)); |
| 52 | }; |
| 53 | return success(IsSuccess: llvm::all_of(Range: op->getOperands(), P: nonNegativePred) && |
| 54 | llvm::all_of(Range: op->getResults(), P: nonNegativePred)); |
| 55 | } |
| 56 | } // namespace mlir::dataflow |
| 57 | |
| 58 | void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { |
| 59 | Lattice::onUpdate(solver); |
| 60 | |
| 61 | // If the integer range can be narrowed to a constant, update the constant |
| 62 | // value of the SSA value. |
| 63 | std::optional<APInt> constant = getValue().getValue().getConstantValue(); |
| 64 | auto value = cast<Value>(Val: anchor); |
| 65 | auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(anchor: value); |
| 66 | if (!constant) |
| 67 | return solver->propagateIfChanged( |
| 68 | state: cv, changed: cv->join(rhs: ConstantValue::getUnknownConstant())); |
| 69 | |
| 70 | Dialect *dialect; |
| 71 | if (auto *parent = value.getDefiningOp()) |
| 72 | dialect = parent->getDialect(); |
| 73 | else |
| 74 | dialect = value.getParentBlock()->getParentOp()->getDialect(); |
| 75 | |
| 76 | Type type = getElementTypeOrSelf(val: value); |
| 77 | solver->propagateIfChanged( |
| 78 | cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); |
| 79 | } |
| 80 | |
| 81 | LogicalResult IntegerRangeAnalysis::visitOperation( |
| 82 | Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, |
| 83 | ArrayRef<IntegerValueRangeLattice *> results) { |
| 84 | auto inferrable = dyn_cast<InferIntRangeInterface>(op); |
| 85 | if (!inferrable) { |
| 86 | setAllToEntryStates(results); |
| 87 | return success(); |
| 88 | } |
| 89 | |
| 90 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
| 91 | auto argRanges = llvm::map_to_vector( |
| 92 | C&: operands, F: [](const IntegerValueRangeLattice *lattice) { |
| 93 | return lattice->getValue(); |
| 94 | }); |
| 95 | |
| 96 | auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
| 97 | auto result = dyn_cast<OpResult>(Val&: v); |
| 98 | if (!result) |
| 99 | return; |
| 100 | assert(llvm::is_contained(op->getResults(), result)); |
| 101 | |
| 102 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
| 103 | IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; |
| 104 | IntegerValueRange oldRange = lattice->getValue(); |
| 105 | |
| 106 | ChangeResult changed = lattice->join(rhs: attrs); |
| 107 | |
| 108 | // Catch loop results with loop variant bounds and conservatively make |
| 109 | // them [-inf, inf] so we don't circle around infinitely often (because |
| 110 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
| 111 | // and often can't). |
| 112 | bool isYieldedResult = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
| 113 | return op->hasTrait<OpTrait::IsTerminator>(); |
| 114 | }); |
| 115 | if (isYieldedResult && !oldRange.isUninitialized() && |
| 116 | !(lattice->getValue() == oldRange)) { |
| 117 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
| 118 | changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v)); |
| 119 | } |
| 120 | propagateIfChanged(lattice, changed); |
| 121 | }; |
| 122 | |
| 123 | inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
| 124 | return success(); |
| 125 | } |
| 126 | |
| 127 | void IntegerRangeAnalysis::visitNonControlFlowArguments( |
| 128 | Operation *op, const RegionSuccessor &successor, |
| 129 | ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { |
| 130 | if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { |
| 131 | LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n" ); |
| 132 | |
| 133 | auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { |
| 134 | return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); |
| 135 | }); |
| 136 | |
| 137 | auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
| 138 | auto arg = dyn_cast<BlockArgument>(Val&: v); |
| 139 | if (!arg) |
| 140 | return; |
| 141 | if (!llvm::is_contained(Range: successor.getSuccessor()->getArguments(), Element: arg)) |
| 142 | return; |
| 143 | |
| 144 | LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n" ); |
| 145 | IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; |
| 146 | IntegerValueRange oldRange = lattice->getValue(); |
| 147 | |
| 148 | ChangeResult changed = lattice->join(rhs: attrs); |
| 149 | |
| 150 | // Catch loop results with loop variant bounds and conservatively make |
| 151 | // them [-inf, inf] so we don't circle around infinitely often (because |
| 152 | // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
| 153 | // and often can't). |
| 154 | bool isYieldedValue = llvm::any_of(Range: v.getUsers(), P: [](Operation *op) { |
| 155 | return op->hasTrait<OpTrait::IsTerminator>(); |
| 156 | }); |
| 157 | if (isYieldedValue && !oldRange.isUninitialized() && |
| 158 | !(lattice->getValue() == oldRange)) { |
| 159 | LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n" ); |
| 160 | changed |= lattice->join(rhs: IntegerValueRange::getMaxRange(value: v)); |
| 161 | } |
| 162 | propagateIfChanged(lattice, changed); |
| 163 | }; |
| 164 | |
| 165 | inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
| 166 | return; |
| 167 | } |
| 168 | |
| 169 | /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() |
| 170 | /// on a LoopLikeInterface return the lower/upper bound for that result if |
| 171 | /// possible. |
| 172 | auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, |
| 173 | Type boundType, Block *block, bool getUpper) { |
| 174 | unsigned int width = ConstantIntRanges::getStorageBitwidth(type: boundType); |
| 175 | if (loopBound.has_value()) { |
| 176 | if (auto attr = dyn_cast<Attribute>(Val&: *loopBound)) { |
| 177 | if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) |
| 178 | return bound.getValue(); |
| 179 | } else if (auto value = llvm::dyn_cast_if_present<Value>(Val&: *loopBound)) { |
| 180 | const IntegerValueRangeLattice *lattice = |
| 181 | getLatticeElementFor(getProgramPointBefore(block), value); |
| 182 | if (lattice != nullptr && !lattice->getValue().isUninitialized()) |
| 183 | return getUpper ? lattice->getValue().getValue().smax() |
| 184 | : lattice->getValue().getValue().smin(); |
| 185 | } |
| 186 | } |
| 187 | // Given the results of getConstant{Lower,Upper}Bound() |
| 188 | // or getConstantStep() on a LoopLikeInterface return the lower/upper |
| 189 | // bound |
| 190 | return getUpper ? APInt::getSignedMaxValue(width) |
| 191 | : APInt::getSignedMinValue(width); |
| 192 | }; |
| 193 | |
| 194 | // Infer bounds for loop arguments that have static bounds |
| 195 | if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { |
| 196 | std::optional<Value> iv = loop.getSingleInductionVar(); |
| 197 | if (!iv) { |
| 198 | return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( |
| 199 | op, successor, argLattices, firstIndex); |
| 200 | } |
| 201 | Block *block = iv->getParentBlock(); |
| 202 | std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); |
| 203 | std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); |
| 204 | std::optional<OpFoldResult> step = loop.getSingleStep(); |
| 205 | APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, |
| 206 | /*getUpper=*/false); |
| 207 | APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, |
| 208 | /*getUpper=*/true); |
| 209 | // Assume positivity for uniscoverable steps by way of getUpper = true. |
| 210 | APInt stepVal = |
| 211 | getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); |
| 212 | |
| 213 | if (stepVal.isNegative()) { |
| 214 | std::swap(a&: min, b&: max); |
| 215 | } else { |
| 216 | // Correct the upper bound by subtracting 1 so that it becomes a <= |
| 217 | // bound, because loops do not generally include their upper bound. |
| 218 | max -= 1; |
| 219 | } |
| 220 | |
| 221 | // If we infer the lower bound to be larger than the upper bound, the |
| 222 | // resulting range is meaningless and should not be used in further |
| 223 | // inferences. |
| 224 | if (max.sge(RHS: min)) { |
| 225 | IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); |
| 226 | auto ivRange = ConstantIntRanges::fromSigned(smin: min, smax: max); |
| 227 | propagateIfChanged(ivEntry, ivEntry->join(rhs: IntegerValueRange{ivRange})); |
| 228 | } |
| 229 | return; |
| 230 | } |
| 231 | |
| 232 | return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( |
| 233 | op, successor, argLattices, firstIndex); |
| 234 | } |
| 235 | |