1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Index/IR/IndexOps.h"
10#include "mlir/Interfaces/InferIntRangeInterface.h"
11#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13#include "llvm/Support/Debug.h"
14#include <optional>
15
16#define DEBUG_TYPE "int-range-analysis"
17
18using namespace mlir;
19using namespace mlir::index;
20using namespace mlir::intrange;
21
22//===----------------------------------------------------------------------===//
23// Constants
24//===----------------------------------------------------------------------===//
25
26void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27 SetIntRangeFn setResultRange) {
28 const APInt &value = getValue();
29 setResultRange(getResult(), ConstantIntRanges::constant(value));
30}
31
32void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33 SetIntRangeFn setResultRange) {
34 bool value = getValue();
35 APInt asInt(/*numBits=*/1, value);
36 setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37}
38
39//===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations will have their results inferred
41// using both the 64-bit values and truncated 32-bit values of their inputs,
42// with the results being the union of those inferences, except where the
43// truncation of the 64-bit result is equal to the 32-bit result (at which time
44// we take the 64-bit result).
45//===----------------------------------------------------------------------===//
46
47void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
48 SetIntRangeFn setResultRange) {
49 setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
50}
51
52void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
53 SetIntRangeFn setResultRange) {
54 setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
55}
56
57void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58 SetIntRangeFn setResultRange) {
59 setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
60}
61
62void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
63 SetIntRangeFn setResultRange) {
64 setResultRange(getResult(),
65 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
66}
67
68void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
69 SetIntRangeFn setResultRange) {
70 setResultRange(getResult(),
71 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
72}
73
74void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
75 SetIntRangeFn setResultRange) {
76 setResultRange(getResult(),
77 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
78}
79
80void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
81 SetIntRangeFn setResultRange) {
82 setResultRange(getResult(),
83 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
84}
85
86void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87 SetIntRangeFn setResultRange) {
88 return setResultRange(
89 getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
90}
91
92void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
93 SetIntRangeFn setResultRange) {
94 setResultRange(getResult(),
95 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
96}
97
98void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
99 SetIntRangeFn setResultRange) {
100 setResultRange(getResult(),
101 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
102}
103
104void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
105 SetIntRangeFn setResultRange) {
106 setResultRange(getResult(),
107 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
108}
109
110void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
111 SetIntRangeFn setResultRange) {
112 setResultRange(getResult(),
113 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
114}
115
116void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
117 SetIntRangeFn setResultRange) {
118 setResultRange(getResult(),
119 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
120}
121
122void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
123 SetIntRangeFn setResultRange) {
124 setResultRange(getResult(),
125 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
126}
127
128void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129 SetIntRangeFn setResultRange) {
130 setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
131}
132
133void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
134 SetIntRangeFn setResultRange) {
135 setResultRange(getResult(),
136 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
137}
138
139void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
140 SetIntRangeFn setResultRange) {
141 setResultRange(getResult(),
142 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
143}
144
145void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
146 SetIntRangeFn setResultRange) {
147 setResultRange(getResult(),
148 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
149}
150
151void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152 SetIntRangeFn setResultRange) {
153 setResultRange(getResult(),
154 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
155}
156
157void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
158 SetIntRangeFn setResultRange) {
159 setResultRange(getResult(),
160 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
161}
162
163//===----------------------------------------------------------------------===//
164// Casts
165//===----------------------------------------------------------------------===//
166
167static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
168 unsigned srcWidth, unsigned destWidth,
169 bool isSigned) {
170 if (srcWidth < destWidth)
171 return isSigned ? extSIRange(range, destWidth)
172 : extUIRange(range, destWidth);
173 if (srcWidth > destWidth)
174 return truncRange(range, destWidth);
175 return range;
176}
177
178// When casting to `index`, we will take the union of the possible fixed-width
179// casts.
180static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
181 Type sourceType, Type destType,
182 bool isSigned) {
183 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType);
184 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType);
185 if (sourceType.isIndex())
186 return makeLikeDest(range, srcWidth, destWidth, isSigned);
187 // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
188 ConstantIntRanges storageRange =
189 makeLikeDest(range, srcWidth, destWidth, isSigned);
190 ConstantIntRanges minWidthRange =
191 makeLikeDest(range, srcWidth, destWidth: indexMinWidth, isSigned);
192 ConstantIntRanges minWidthExt = extRange(range: minWidthRange, destWidth);
193 ConstantIntRanges ret = storageRange.rangeUnion(other: minWidthExt);
194 return ret;
195}
196
197void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
198 SetIntRangeFn setResultRange) {
199 Type sourceType = getOperand().getType();
200 Type destType = getResult().getType();
201 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
202 /*isSigned=*/true));
203}
204
205void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
206 SetIntRangeFn setResultRange) {
207 Type sourceType = getOperand().getType();
208 Type destType = getResult().getType();
209 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
210 /*isSigned=*/false));
211}
212
213//===----------------------------------------------------------------------===//
214// CmpOp
215//===----------------------------------------------------------------------===//
216
217void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
218 SetIntRangeFn setResultRange) {
219 index::IndexCmpPredicate indexPred = getPred();
220 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
221 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
222
223 APInt min = APInt::getZero(1);
224 APInt max = APInt::getAllOnes(1);
225
226 std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
227
228 ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
229 rhsTrunc = truncRange(rhs, indexMinWidth);
230 std::optional<bool> truthValue32 =
231 intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
232
233 if (truthValue64 == truthValue32) {
234 if (truthValue64.has_value() && *truthValue64)
235 min = max;
236 else if (truthValue64.has_value() && !(*truthValue64))
237 max = min;
238 }
239 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
240}
241
242//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
244//===----------------------------------------------------------------------===//
245
246void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
247 SetIntRangeFn setResultRange) {
248 unsigned storageWidth =
249 ConstantIntRanges::getStorageBitwidth(getResult().getType());
250 APInt min(/*numBits=*/storageWidth, indexMinWidth);
251 APInt max(/*numBits=*/storageWidth, indexMaxWidth);
252 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
253}
254

source code of mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp