//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index ===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
10 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
11 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
12 | |
13 | #include "llvm/Support/Debug.h" |
14 | #include <optional> |
15 | |
16 | #define DEBUG_TYPE "int-range-analysis" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::index; |
20 | using namespace mlir::intrange; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Constants |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
27 | SetIntRangeFn setResultRange) { |
28 | const APInt &value = getValue(); |
29 | setResultRange(getResult(), ConstantIntRanges::constant(value)); |
30 | } |
31 | |
32 | void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
33 | SetIntRangeFn setResultRange) { |
34 | bool value = getValue(); |
35 | APInt asInt(/*numBits=*/1, value); |
36 | setResultRange(getResult(), ConstantIntRanges::constant(asInt)); |
37 | } |
38 | |
39 | //===----------------------------------------------------------------------===// |
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (in which
// case we take the 64-bit result).
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | // Some arithmetic inference functions allow specifying special overflow / wrap |
48 | // behavior. We do not require this for the IndexOps and use this helper to call |
49 | // the inference function without any `OverflowFlags`. |
50 | static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)> |
51 | inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) { |
52 | return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) { |
53 | return inferWithOvfFn(argRanges, OverflowFlags::None); |
54 | }; |
55 | } |
56 | |
57 | void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
58 | SetIntRangeFn setResultRange) { |
59 | setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd), |
60 | argRanges, CmpMode::Both)); |
61 | } |
62 | |
63 | void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
64 | SetIntRangeFn setResultRange) { |
65 | setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub), |
66 | argRanges, CmpMode::Both)); |
67 | } |
68 | |
69 | void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
70 | SetIntRangeFn setResultRange) { |
71 | setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul), |
72 | argRanges, CmpMode::Both)); |
73 | } |
74 | |
75 | void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
76 | SetIntRangeFn setResultRange) { |
77 | setResultRange(getResult(), |
78 | inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); |
79 | } |
80 | |
81 | void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
82 | SetIntRangeFn setResultRange) { |
83 | setResultRange(getResult(), |
84 | inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); |
85 | } |
86 | |
87 | void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
88 | SetIntRangeFn setResultRange) { |
89 | setResultRange(getResult(), |
90 | inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); |
91 | } |
92 | |
93 | void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
94 | SetIntRangeFn setResultRange) { |
95 | setResultRange(getResult(), |
96 | inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); |
97 | } |
98 | |
99 | void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
100 | SetIntRangeFn setResultRange) { |
101 | return setResultRange( |
102 | getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); |
103 | } |
104 | |
105 | void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
106 | SetIntRangeFn setResultRange) { |
107 | setResultRange(getResult(), |
108 | inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); |
109 | } |
110 | |
111 | void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
112 | SetIntRangeFn setResultRange) { |
113 | setResultRange(getResult(), |
114 | inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); |
115 | } |
116 | |
117 | void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
118 | SetIntRangeFn setResultRange) { |
119 | setResultRange(getResult(), |
120 | inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); |
121 | } |
122 | |
123 | void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
124 | SetIntRangeFn setResultRange) { |
125 | setResultRange(getResult(), |
126 | inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); |
127 | } |
128 | |
129 | void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
130 | SetIntRangeFn setResultRange) { |
131 | setResultRange(getResult(), |
132 | inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); |
133 | } |
134 | |
135 | void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
136 | SetIntRangeFn setResultRange) { |
137 | setResultRange(getResult(), |
138 | inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); |
139 | } |
140 | |
141 | void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
142 | SetIntRangeFn setResultRange) { |
143 | setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl), |
144 | argRanges, CmpMode::Both)); |
145 | } |
146 | |
147 | void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
148 | SetIntRangeFn setResultRange) { |
149 | setResultRange(getResult(), |
150 | inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); |
151 | } |
152 | |
153 | void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
154 | SetIntRangeFn setResultRange) { |
155 | setResultRange(getResult(), |
156 | inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); |
157 | } |
158 | |
159 | void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
160 | SetIntRangeFn setResultRange) { |
161 | setResultRange(getResult(), |
162 | inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); |
163 | } |
164 | |
165 | void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
166 | SetIntRangeFn setResultRange) { |
167 | setResultRange(getResult(), |
168 | inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); |
169 | } |
170 | |
171 | void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
172 | SetIntRangeFn setResultRange) { |
173 | setResultRange(getResult(), |
174 | inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); |
175 | } |
176 | |
177 | //===----------------------------------------------------------------------===// |
178 | // Casts |
179 | //===----------------------------------------------------------------------===// |
180 | |
181 | static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, |
182 | unsigned srcWidth, unsigned destWidth, |
183 | bool isSigned) { |
184 | if (srcWidth < destWidth) |
185 | return isSigned ? extSIRange(range, destWidth) |
186 | : extUIRange(range, destWidth); |
187 | if (srcWidth > destWidth) |
188 | return truncRange(range, destWidth); |
189 | return range; |
190 | } |
191 | |
192 | // When casting to `index`, we will take the union of the possible fixed-width |
193 | // casts. |
194 | static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, |
195 | Type sourceType, Type destType, |
196 | bool isSigned) { |
197 | unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType); |
198 | unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType); |
199 | if (sourceType.isIndex()) |
200 | return makeLikeDest(range, srcWidth, destWidth, isSigned); |
201 | // We are casting to indexs, so use the union of the 32-bit and 64-bit casts |
202 | ConstantIntRanges storageRange = |
203 | makeLikeDest(range, srcWidth, destWidth, isSigned); |
204 | ConstantIntRanges minWidthRange = |
205 | makeLikeDest(range, srcWidth, destWidth: indexMinWidth, isSigned); |
206 | ConstantIntRanges minWidthExt = extRange(range: minWidthRange, destWidth); |
207 | ConstantIntRanges ret = storageRange.rangeUnion(other: minWidthExt); |
208 | return ret; |
209 | } |
210 | |
211 | void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
212 | SetIntRangeFn setResultRange) { |
213 | Type sourceType = getOperand().getType(); |
214 | Type destType = getResult().getType(); |
215 | setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, |
216 | /*isSigned=*/true)); |
217 | } |
218 | |
219 | void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
220 | SetIntRangeFn setResultRange) { |
221 | Type sourceType = getOperand().getType(); |
222 | Type destType = getResult().getType(); |
223 | setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, |
224 | /*isSigned=*/false)); |
225 | } |
226 | |
227 | //===----------------------------------------------------------------------===// |
228 | // CmpOp |
229 | //===----------------------------------------------------------------------===// |
230 | |
231 | void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
232 | SetIntRangeFn setResultRange) { |
233 | index::IndexCmpPredicate indexPred = getPred(); |
234 | intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred); |
235 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
236 | |
237 | APInt min = APInt::getZero(1); |
238 | APInt max = APInt::getAllOnes(1); |
239 | |
240 | std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs); |
241 | |
242 | ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), |
243 | rhsTrunc = truncRange(rhs, indexMinWidth); |
244 | std::optional<bool> truthValue32 = |
245 | intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); |
246 | |
247 | if (truthValue64 == truthValue32) { |
248 | if (truthValue64.has_value() && *truthValue64) |
249 | min = max; |
250 | else if (truthValue64.has_value() && !(*truthValue64)) |
251 | max = min; |
252 | } |
253 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); |
254 | } |
255 | |
256 | //===----------------------------------------------------------------------===// |
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
258 | //===----------------------------------------------------------------------===// |
259 | |
260 | void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
261 | SetIntRangeFn setResultRange) { |
262 | unsigned storageWidth = |
263 | ConstantIntRanges::getStorageBitwidth(getResult().getType()); |
264 | APInt min(/*numBits=*/storageWidth, indexMinWidth); |
265 | APInt max(/*numBits=*/storageWidth, indexMaxWidth); |
266 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); |
267 | } |
268 | |