1//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/InferIntRangeInterface.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
12#include <optional>
13
14using namespace mlir;
15
16bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
17 return umin().getBitWidth() == other.umin().getBitWidth() &&
18 umin() == other.umin() && umax() == other.umax() &&
19 smin() == other.smin() && smax() == other.smax();
20}
21
22const APInt &ConstantIntRanges::umin() const { return uminVal; }
23
24const APInt &ConstantIntRanges::umax() const { return umaxVal; }
25
26const APInt &ConstantIntRanges::smin() const { return sminVal; }
27
28const APInt &ConstantIntRanges::smax() const { return smaxVal; }
29
30unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
31 if (type.isIndex())
32 return IndexType::kInternalStorageBitWidth;
33 if (auto integerType = dyn_cast<IntegerType>(type))
34 return integerType.getWidth();
35 // Non-integer types have their bounds stored in width 0 `APInt`s.
36 return 0;
37}
38
39ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
40 return fromUnsigned(umin: APInt::getZero(numBits: bitwidth), umax: APInt::getMaxValue(numBits: bitwidth));
41}
42
43ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
44 return {value, value, value, value};
45}
46
47ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
48 bool isSigned) {
49 if (isSigned)
50 return fromSigned(smin: min, smax: max);
51 return fromUnsigned(umin: min, umax: max);
52}
53
54ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
55 const APInt &smax) {
56 unsigned int width = smin.getBitWidth();
57 APInt umin, umax;
58 if (smin.isNonNegative() == smax.isNonNegative()) {
59 umin = smin.ult(RHS: smax) ? smin : smax;
60 umax = smin.ugt(RHS: smax) ? smin : smax;
61 } else {
62 umin = APInt::getMinValue(numBits: width);
63 umax = APInt::getMaxValue(numBits: width);
64 }
65 return {umin, umax, smin, smax};
66}
67
68ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
69 const APInt &umax) {
70 unsigned int width = umin.getBitWidth();
71 APInt smin, smax;
72 if (umin.isNonNegative() == umax.isNonNegative()) {
73 smin = umin.slt(RHS: umax) ? umin : umax;
74 smax = umin.sgt(RHS: umax) ? umin : umax;
75 } else {
76 smin = APInt::getSignedMinValue(numBits: width);
77 smax = APInt::getSignedMaxValue(numBits: width);
78 }
79 return {umin, umax, smin, smax};
80}
81
82ConstantIntRanges
83ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
84 // "Not an integer" poisons everything and also cannot be fed to comparison
85 // operators.
86 if (umin().getBitWidth() == 0)
87 return *this;
88 if (other.umin().getBitWidth() == 0)
89 return other;
90
91 const APInt &uminUnion = umin().ult(RHS: other.umin()) ? umin() : other.umin();
92 const APInt &umaxUnion = umax().ugt(RHS: other.umax()) ? umax() : other.umax();
93 const APInt &sminUnion = smin().slt(RHS: other.smin()) ? smin() : other.smin();
94 const APInt &smaxUnion = smax().sgt(RHS: other.smax()) ? smax() : other.smax();
95
96 return {uminUnion, umaxUnion, sminUnion, smaxUnion};
97}
98
99ConstantIntRanges
100ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
101 // "Not an integer" poisons everything and also cannot be fed to comparison
102 // operators.
103 if (umin().getBitWidth() == 0)
104 return *this;
105 if (other.umin().getBitWidth() == 0)
106 return other;
107
108 const APInt &uminIntersect = umin().ugt(RHS: other.umin()) ? umin() : other.umin();
109 const APInt &umaxIntersect = umax().ult(RHS: other.umax()) ? umax() : other.umax();
110 const APInt &sminIntersect = smin().sgt(RHS: other.smin()) ? smin() : other.smin();
111 const APInt &smaxIntersect = smax().slt(RHS: other.smax()) ? smax() : other.smax();
112
113 return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
114}
115
116std::optional<APInt> ConstantIntRanges::getConstantValue() const {
117 // Note: we need to exclude the trivially-equal width 0 values here.
118 if (umin() == umax() && umin().getBitWidth() != 0)
119 return umin();
120 if (smin() == smax() && smin().getBitWidth() != 0)
121 return smin();
122 return std::nullopt;
123}
124
125raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
126 return os << "unsigned : [" << range.umin() << ", " << range.umax()
127 << "] signed : [" << range.smin() << ", " << range.smax() << "]";
128}
129

source code of mlir/lib/Interfaces/InferIntRangeInterface.cpp