1//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/InferIntRangeInterface.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/IR/TypeUtilities.h"
12#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
13#include <optional>
14
15using namespace mlir;
16
17bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
18 return umin().getBitWidth() == other.umin().getBitWidth() &&
19 umin() == other.umin() && umax() == other.umax() &&
20 smin() == other.smin() && smax() == other.smax();
21}
22
23const APInt &ConstantIntRanges::umin() const { return uminVal; }
24
25const APInt &ConstantIntRanges::umax() const { return umaxVal; }
26
27const APInt &ConstantIntRanges::smin() const { return sminVal; }
28
29const APInt &ConstantIntRanges::smax() const { return smaxVal; }
30
31unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
32 type = getElementTypeOrSelf(type);
33 if (type.isIndex())
34 return IndexType::kInternalStorageBitWidth;
35 if (auto integerType = dyn_cast<IntegerType>(type))
36 return integerType.getWidth();
37 // Non-integer types have their bounds stored in width 0 `APInt`s.
38 return 0;
39}
40
41ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
42 return fromUnsigned(umin: APInt::getZero(numBits: bitwidth), umax: APInt::getMaxValue(numBits: bitwidth));
43}
44
45ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
46 return {value, value, value, value};
47}
48
49ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
50 bool isSigned) {
51 if (isSigned)
52 return fromSigned(smin: min, smax: max);
53 return fromUnsigned(umin: min, umax: max);
54}
55
56ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
57 const APInt &smax) {
58 unsigned int width = smin.getBitWidth();
59 APInt umin, umax;
60 if (smin.isNonNegative() == smax.isNonNegative()) {
61 umin = smin.ult(RHS: smax) ? smin : smax;
62 umax = smin.ugt(RHS: smax) ? smin : smax;
63 } else {
64 umin = APInt::getMinValue(numBits: width);
65 umax = APInt::getMaxValue(numBits: width);
66 }
67 return {umin, umax, smin, smax};
68}
69
70ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
71 const APInt &umax) {
72 unsigned int width = umin.getBitWidth();
73 APInt smin, smax;
74 if (umin.isNonNegative() == umax.isNonNegative()) {
75 smin = umin.slt(RHS: umax) ? umin : umax;
76 smax = umin.sgt(RHS: umax) ? umin : umax;
77 } else {
78 smin = APInt::getSignedMinValue(numBits: width);
79 smax = APInt::getSignedMaxValue(numBits: width);
80 }
81 return {umin, umax, smin, smax};
82}
83
84ConstantIntRanges
85ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
86 // "Not an integer" poisons everything and also cannot be fed to comparison
87 // operators.
88 if (umin().getBitWidth() == 0)
89 return *this;
90 if (other.umin().getBitWidth() == 0)
91 return other;
92
93 const APInt &uminUnion = umin().ult(RHS: other.umin()) ? umin() : other.umin();
94 const APInt &umaxUnion = umax().ugt(RHS: other.umax()) ? umax() : other.umax();
95 const APInt &sminUnion = smin().slt(RHS: other.smin()) ? smin() : other.smin();
96 const APInt &smaxUnion = smax().sgt(RHS: other.smax()) ? smax() : other.smax();
97
98 return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99}
100
101ConstantIntRanges
102ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
103 // "Not an integer" poisons everything and also cannot be fed to comparison
104 // operators.
105 if (umin().getBitWidth() == 0)
106 return *this;
107 if (other.umin().getBitWidth() == 0)
108 return other;
109
110 const APInt &uminIntersect = umin().ugt(RHS: other.umin()) ? umin() : other.umin();
111 const APInt &umaxIntersect = umax().ult(RHS: other.umax()) ? umax() : other.umax();
112 const APInt &sminIntersect = smin().sgt(RHS: other.smin()) ? smin() : other.smin();
113 const APInt &smaxIntersect = smax().slt(RHS: other.smax()) ? smax() : other.smax();
114
115 return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116}
117
118std::optional<APInt> ConstantIntRanges::getConstantValue() const {
119 // Note: we need to exclude the trivially-equal width 0 values here.
120 if (umin() == umax() && umin().getBitWidth() != 0)
121 return umin();
122 if (smin() == smax() && smin().getBitWidth() != 0)
123 return smin();
124 return std::nullopt;
125}
126
127raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128 os << "unsigned : [";
129 range.umin().print(OS&: os, /*isSigned*/ false);
130 os << ", ";
131 range.umax().print(OS&: os, /*isSigned*/ false);
132 return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
133}
134
135IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
136 unsigned width = ConstantIntRanges::getStorageBitwidth(type: value.getType());
137 APInt umin = APInt::getMinValue(numBits: width);
138 APInt umax = APInt::getMaxValue(numBits: width);
139 APInt smin = width != 0 ? APInt::getSignedMinValue(numBits: width) : umin;
140 APInt smax = width != 0 ? APInt::getSignedMaxValue(numBits: width) : umax;
141 return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
142}
143
144raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
145 range.print(os);
146 return os;
147}
148
149void mlir::intrange::detail::defaultInferResultRanges(
150 InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
151 SetIntLatticeFn setResultRanges) {
152 llvm::SmallVector<ConstantIntRanges> unpacked;
153 unpacked.reserve(N: argRanges.size());
154
155 for (const IntegerValueRange &range : argRanges) {
156 if (range.isUninitialized())
157 return;
158 unpacked.push_back(Elt: range.getValue());
159 }
160
161 interface.inferResultRanges(
162 unpacked,
163 [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
164 setResultRanges(value, IntegerValueRange{argRanges});
165 });
166}
167
168void mlir::intrange::detail::defaultInferResultRangesFromOptional(
169 InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
170 SetIntRangeFn setResultRanges) {
171 auto ranges = llvm::to_vector_of<IntegerValueRange>(Range&: argRanges);
172 interface.inferResultRangesFromOptional(
173 ranges,
174 [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
175 if (!argRanges.isUninitialized())
176 setResultRanges(value, argRanges.getValue());
177 });
178}
179

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Interfaces/InferIntRangeInterface.cpp