1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Index/IR/IndexOps.h"
10#include "mlir/Interfaces/InferIntRangeInterface.h"
11#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13#include "llvm/Support/Debug.h"
14#include <optional>
15
16#define DEBUG_TYPE "int-range-analysis"
17
18using namespace mlir;
19using namespace mlir::index;
20using namespace mlir::intrange;
21
22//===----------------------------------------------------------------------===//
23// Constants
24//===----------------------------------------------------------------------===//
25
26void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27 SetIntRangeFn setResultRange) {
28 const APInt &value = getValue();
29 setResultRange(getResult(), ConstantIntRanges::constant(value));
30}
31
32void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33 SetIntRangeFn setResultRange) {
34 bool value = getValue();
35 APInt asInt(/*numBits=*/1, value);
36 setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37}
38
39//===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations have their results inferred
// using both the 64-bit values and the truncated 32-bit values of their inputs,
// with the result being the union of those inferences, except where the
// truncation of the 64-bit result equals the 32-bit result (in which case we
// take the 64-bit result).
45//===----------------------------------------------------------------------===//
46
47// Some arithmetic inference functions allow specifying special overflow / wrap
48// behavior. We do not require this for the IndexOps and use this helper to call
49// the inference function without any `OverflowFlags`.
50static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
51inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
52 return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
53 return inferWithOvfFn(argRanges, OverflowFlags::None);
54 };
55}
56
57void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58 SetIntRangeFn setResultRange) {
59 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
60 argRanges, CmpMode::Both));
61}
62
63void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
64 SetIntRangeFn setResultRange) {
65 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
66 argRanges, CmpMode::Both));
67}
68
69void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
70 SetIntRangeFn setResultRange) {
71 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
72 argRanges, CmpMode::Both));
73}
74
75void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76 SetIntRangeFn setResultRange) {
77 setResultRange(getResult(),
78 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
79}
80
81void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
82 SetIntRangeFn setResultRange) {
83 setResultRange(getResult(),
84 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
85}
86
87void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88 SetIntRangeFn setResultRange) {
89 setResultRange(getResult(),
90 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
91}
92
93void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
94 SetIntRangeFn setResultRange) {
95 setResultRange(getResult(),
96 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
97}
98
99void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
100 SetIntRangeFn setResultRange) {
101 return setResultRange(
102 getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
103}
104
105void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106 SetIntRangeFn setResultRange) {
107 setResultRange(getResult(),
108 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
109}
110
111void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112 SetIntRangeFn setResultRange) {
113 setResultRange(getResult(),
114 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
115}
116
117void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
118 SetIntRangeFn setResultRange) {
119 setResultRange(getResult(),
120 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
121}
122
123void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
124 SetIntRangeFn setResultRange) {
125 setResultRange(getResult(),
126 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
127}
128
129void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130 SetIntRangeFn setResultRange) {
131 setResultRange(getResult(),
132 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
133}
134
135void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
136 SetIntRangeFn setResultRange) {
137 setResultRange(getResult(),
138 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
139}
140
141void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142 SetIntRangeFn setResultRange) {
143 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
144 argRanges, CmpMode::Both));
145}
146
147void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148 SetIntRangeFn setResultRange) {
149 setResultRange(getResult(),
150 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
151}
152
153void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
154 SetIntRangeFn setResultRange) {
155 setResultRange(getResult(),
156 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
157}
158
159void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160 SetIntRangeFn setResultRange) {
161 setResultRange(getResult(),
162 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
163}
164
165void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166 SetIntRangeFn setResultRange) {
167 setResultRange(getResult(),
168 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
169}
170
171void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
172 SetIntRangeFn setResultRange) {
173 setResultRange(getResult(),
174 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
175}
176
177//===----------------------------------------------------------------------===//
178// Casts
179//===----------------------------------------------------------------------===//
180
181static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
182 unsigned srcWidth, unsigned destWidth,
183 bool isSigned) {
184 if (srcWidth < destWidth)
185 return isSigned ? extSIRange(range, destWidth)
186 : extUIRange(range, destWidth);
187 if (srcWidth > destWidth)
188 return truncRange(range, destWidth);
189 return range;
190}
191
192// When casting to `index`, we will take the union of the possible fixed-width
193// casts.
194static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
195 Type sourceType, Type destType,
196 bool isSigned) {
197 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType);
198 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType);
199 if (sourceType.isIndex())
200 return makeLikeDest(range, srcWidth, destWidth, isSigned);
201 // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
202 ConstantIntRanges storageRange =
203 makeLikeDest(range, srcWidth, destWidth, isSigned);
204 ConstantIntRanges minWidthRange =
205 makeLikeDest(range, srcWidth, destWidth: indexMinWidth, isSigned);
206 ConstantIntRanges minWidthExt = extRange(range: minWidthRange, destWidth);
207 ConstantIntRanges ret = storageRange.rangeUnion(other: minWidthExt);
208 return ret;
209}
210
211void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
212 SetIntRangeFn setResultRange) {
213 Type sourceType = getOperand().getType();
214 Type destType = getResult().getType();
215 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
216 /*isSigned=*/true));
217}
218
219void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
220 SetIntRangeFn setResultRange) {
221 Type sourceType = getOperand().getType();
222 Type destType = getResult().getType();
223 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
224 /*isSigned=*/false));
225}
226
227//===----------------------------------------------------------------------===//
228// CmpOp
229//===----------------------------------------------------------------------===//
230
231void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
232 SetIntRangeFn setResultRange) {
233 index::IndexCmpPredicate indexPred = getPred();
234 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
235 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
236
237 APInt min = APInt::getZero(1);
238 APInt max = APInt::getAllOnes(1);
239
240 std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
241
242 ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
243 rhsTrunc = truncRange(rhs, indexMinWidth);
244 std::optional<bool> truthValue32 =
245 intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
246
247 if (truthValue64 == truthValue32) {
248 if (truthValue64.has_value() && *truthValue64)
249 min = max;
250 else if (truthValue64.has_value() && !(*truthValue64))
251 max = min;
252 }
253 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
254}
255
256//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
258//===----------------------------------------------------------------------===//
259
260void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
261 SetIntRangeFn setResultRange) {
262 unsigned storageWidth =
263 ConstantIntRanges::getStorageBitwidth(getResult().getType());
264 APInt min(/*numBits=*/storageWidth, indexMinWidth);
265 APInt max(/*numBits=*/storageWidth, indexMaxWidth);
266 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
267}
268

// Source listing: mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
// (captured from a web code browser provided by KDAB).