1 | //===- InferIntRangeCommon.cpp - Inference for common ops ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains implementations of range inference for operations that are |
10 | // common to both the `arith` and `index` dialects to facilitate reuse. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
15 | |
16 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
17 | |
18 | #include "llvm/ADT/ArrayRef.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | |
21 | #include "llvm/Support/Debug.h" |
22 | |
23 | #include <iterator> |
24 | #include <optional> |
25 | |
26 | using namespace mlir; |
27 | |
28 | #define DEBUG_TYPE "int-range-analysis" |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // General utilities |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Function that evaluates the result of doing something on arithmetic |
35 | /// constants and returns std::nullopt on overflow. |
36 | using ConstArithFn = |
37 | function_ref<std::optional<APInt>(const APInt &, const APInt &)>; |
38 | |
39 | /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, |
40 | /// If either computation overflows, make the result unbounded. |
41 | static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, |
42 | const APInt &minRight, |
43 | const APInt &maxLeft, |
44 | const APInt &maxRight, bool isSigned) { |
45 | std::optional<APInt> maybeMin = op(minLeft, minRight); |
46 | std::optional<APInt> maybeMax = op(maxLeft, maxRight); |
47 | if (maybeMin && maybeMax) |
48 | return ConstantIntRanges::range(min: *maybeMin, max: *maybeMax, isSigned); |
49 | return ConstantIntRanges::maxRange(bitwidth: minLeft.getBitWidth()); |
50 | } |
51 | |
52 | /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, |
53 | /// ignoring unbounded values. Returns the maximal range if `op` overflows. |
54 | static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, |
55 | ArrayRef<APInt> rhs, bool isSigned) { |
56 | unsigned width = lhs[0].getBitWidth(); |
57 | APInt min = |
58 | isSigned ? APInt::getSignedMaxValue(numBits: width) : APInt::getMaxValue(numBits: width); |
59 | APInt max = |
60 | isSigned ? APInt::getSignedMinValue(numBits: width) : APInt::getZero(numBits: width); |
61 | for (const APInt &left : lhs) { |
62 | for (const APInt &right : rhs) { |
63 | std::optional<APInt> maybeThisResult = op(left, right); |
64 | if (!maybeThisResult) |
65 | return ConstantIntRanges::maxRange(bitwidth: width); |
66 | APInt result = std::move(*maybeThisResult); |
67 | min = (isSigned ? result.slt(RHS: min) : result.ult(RHS: min)) ? result : min; |
68 | max = (isSigned ? result.sgt(RHS: max) : result.ugt(RHS: max)) ? result : max; |
69 | } |
70 | } |
71 | return ConstantIntRanges::range(min, max, isSigned); |
72 | } |
73 | |
74 | //===----------------------------------------------------------------------===// |
75 | // Ext, trunc, index op handling |
76 | //===----------------------------------------------------------------------===// |
77 | |
78 | ConstantIntRanges |
79 | mlir::intrange::inferIndexOp(InferRangeFn inferFn, |
80 | ArrayRef<ConstantIntRanges> argRanges, |
81 | intrange::CmpMode mode) { |
82 | ConstantIntRanges sixtyFour = inferFn(argRanges); |
83 | SmallVector<ConstantIntRanges, 2> truncated; |
84 | llvm::transform(Range&: argRanges, d_first: std::back_inserter(x&: truncated), |
85 | F: [](const ConstantIntRanges &range) { |
86 | return truncRange(range, /*destWidth=*/indexMinWidth); |
87 | }); |
88 | ConstantIntRanges thirtyTwo = inferFn(truncated); |
89 | ConstantIntRanges thirtyTwoAsSixtyFour = |
90 | extRange(range: thirtyTwo, /*destWidth=*/indexMaxWidth); |
91 | ConstantIntRanges sixtyFourAsThirtyTwo = |
92 | truncRange(range: sixtyFour, /*destWidth=*/indexMinWidth); |
93 | |
94 | LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour |
95 | << " 32-bit = " << thirtyTwo << "\n" ); |
96 | bool truncEqual = false; |
97 | switch (mode) { |
98 | case intrange::CmpMode::Both: |
99 | truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); |
100 | break; |
101 | case intrange::CmpMode::Signed: |
102 | truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && |
103 | thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); |
104 | break; |
105 | case intrange::CmpMode::Unsigned: |
106 | truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && |
107 | thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); |
108 | break; |
109 | } |
110 | if (truncEqual) |
111 | // Returing the 64-bit result preserves more information. |
112 | return sixtyFour; |
113 | ConstantIntRanges merged = sixtyFour.rangeUnion(other: thirtyTwoAsSixtyFour); |
114 | return merged; |
115 | } |
116 | |
117 | ConstantIntRanges mlir::intrange::(const ConstantIntRanges &range, |
118 | unsigned int destWidth) { |
119 | APInt umin = range.umin().zext(width: destWidth); |
120 | APInt umax = range.umax().zext(width: destWidth); |
121 | APInt smin = range.smin().sext(width: destWidth); |
122 | APInt smax = range.smax().sext(width: destWidth); |
123 | return {umin, umax, smin, smax}; |
124 | } |
125 | |
126 | ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, |
127 | unsigned destWidth) { |
128 | APInt umin = range.umin().zext(width: destWidth); |
129 | APInt umax = range.umax().zext(width: destWidth); |
130 | return ConstantIntRanges::fromUnsigned(umin, umax); |
131 | } |
132 | |
133 | ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, |
134 | unsigned destWidth) { |
135 | APInt smin = range.smin().sext(width: destWidth); |
136 | APInt smax = range.smax().sext(width: destWidth); |
137 | return ConstantIntRanges::fromSigned(smin, smax); |
138 | } |
139 | |
140 | ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, |
141 | unsigned int destWidth) { |
142 | // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], |
143 | // the range of the resulting value is not contiguous ind includes 0. |
144 | // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], |
145 | // but you can't truncate [255, 257] similarly. |
146 | bool hasUnsignedRollover = |
147 | range.umin().lshr(shiftAmt: destWidth) != range.umax().lshr(shiftAmt: destWidth); |
148 | APInt umin = hasUnsignedRollover ? APInt::getZero(numBits: destWidth) |
149 | : range.umin().trunc(width: destWidth); |
150 | APInt umax = hasUnsignedRollover ? APInt::getMaxValue(numBits: destWidth) |
151 | : range.umax().trunc(width: destWidth); |
152 | |
153 | // Signed post-truncation rollover will not occur when either: |
154 | // - The high parts of the min and max, plus the sign bit, are the same |
155 | // - The high halves + sign bit of the min and max are either all 1s or all 0s |
156 | // and you won't create a [positive, negative] range by truncating. |
157 | // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 |
158 | // but not [255, 257]_i16 to a range of i8s. You can also truncate |
159 | // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. |
160 | // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) |
161 | // will truncate to 0x7e, which is greater than 0 |
162 | APInt sminHighPart = range.smin().ashr(ShiftAmt: destWidth - 1); |
163 | APInt smaxHighPart = range.smax().ashr(ShiftAmt: destWidth - 1); |
164 | bool hasSignedOverflow = |
165 | (sminHighPart != smaxHighPart) && |
166 | !(sminHighPart.isAllOnes() && |
167 | (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && |
168 | !(sminHighPart.isZero() && smaxHighPart.isZero()); |
169 | APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(numBits: destWidth) |
170 | : range.smin().trunc(width: destWidth); |
171 | APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(numBits: destWidth) |
172 | : range.smax().trunc(width: destWidth); |
173 | return {umin, umax, smin, smax}; |
174 | } |
175 | |
176 | //===----------------------------------------------------------------------===// |
177 | // Addition |
178 | //===----------------------------------------------------------------------===// |
179 | |
180 | ConstantIntRanges |
181 | mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) { |
182 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
183 | ConstArithFn uadd = [](const APInt &a, |
184 | const APInt &b) -> std::optional<APInt> { |
185 | bool overflowed = false; |
186 | APInt result = a.uadd_ov(RHS: b, Overflow&: overflowed); |
187 | return overflowed ? std::optional<APInt>() : result; |
188 | }; |
189 | ConstArithFn sadd = [](const APInt &a, |
190 | const APInt &b) -> std::optional<APInt> { |
191 | bool overflowed = false; |
192 | APInt result = a.sadd_ov(RHS: b, Overflow&: overflowed); |
193 | return overflowed ? std::optional<APInt>() : result; |
194 | }; |
195 | |
196 | ConstantIntRanges urange = computeBoundsBy( |
197 | op: uadd, minLeft: lhs.umin(), minRight: rhs.umin(), maxLeft: lhs.umax(), maxRight: rhs.umax(), /*isSigned=*/false); |
198 | ConstantIntRanges srange = computeBoundsBy( |
199 | op: sadd, minLeft: lhs.smin(), minRight: rhs.smin(), maxLeft: lhs.smax(), maxRight: rhs.smax(), /*isSigned=*/true); |
200 | return urange.intersection(other: srange); |
201 | } |
202 | |
203 | //===----------------------------------------------------------------------===// |
204 | // Subtraction |
205 | //===----------------------------------------------------------------------===// |
206 | |
207 | ConstantIntRanges |
208 | mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) { |
209 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
210 | |
211 | ConstArithFn usub = [](const APInt &a, |
212 | const APInt &b) -> std::optional<APInt> { |
213 | bool overflowed = false; |
214 | APInt result = a.usub_ov(RHS: b, Overflow&: overflowed); |
215 | return overflowed ? std::optional<APInt>() : result; |
216 | }; |
217 | ConstArithFn ssub = [](const APInt &a, |
218 | const APInt &b) -> std::optional<APInt> { |
219 | bool overflowed = false; |
220 | APInt result = a.ssub_ov(RHS: b, Overflow&: overflowed); |
221 | return overflowed ? std::optional<APInt>() : result; |
222 | }; |
223 | ConstantIntRanges urange = computeBoundsBy( |
224 | op: usub, minLeft: lhs.umin(), minRight: rhs.umax(), maxLeft: lhs.umax(), maxRight: rhs.umin(), /*isSigned=*/false); |
225 | ConstantIntRanges srange = computeBoundsBy( |
226 | op: ssub, minLeft: lhs.smin(), minRight: rhs.smax(), maxLeft: lhs.smax(), maxRight: rhs.smin(), /*isSigned=*/true); |
227 | return urange.intersection(other: srange); |
228 | } |
229 | |
230 | //===----------------------------------------------------------------------===// |
231 | // Multiplication |
232 | //===----------------------------------------------------------------------===// |
233 | |
234 | ConstantIntRanges |
235 | mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) { |
236 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
237 | |
238 | ConstArithFn umul = [](const APInt &a, |
239 | const APInt &b) -> std::optional<APInt> { |
240 | bool overflowed = false; |
241 | APInt result = a.umul_ov(RHS: b, Overflow&: overflowed); |
242 | return overflowed ? std::optional<APInt>() : result; |
243 | }; |
244 | ConstArithFn smul = [](const APInt &a, |
245 | const APInt &b) -> std::optional<APInt> { |
246 | bool overflowed = false; |
247 | APInt result = a.smul_ov(RHS: b, Overflow&: overflowed); |
248 | return overflowed ? std::optional<APInt>() : result; |
249 | }; |
250 | |
251 | ConstantIntRanges urange = |
252 | minMaxBy(op: umul, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()}, |
253 | /*isSigned=*/false); |
254 | ConstantIntRanges srange = |
255 | minMaxBy(op: smul, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.smin(), rhs.smax()}, |
256 | /*isSigned=*/true); |
257 | return urange.intersection(other: srange); |
258 | } |
259 | |
260 | //===----------------------------------------------------------------------===// |
261 | // DivU, CeilDivU (Unsigned division) |
262 | //===----------------------------------------------------------------------===// |
263 | |
264 | /// Fix up division results (ex. for ceiling and floor), returning an APInt |
265 | /// if there has been no overflow |
266 | using DivisionFixupFn = function_ref<std::optional<APInt>( |
267 | const APInt &lhs, const APInt &rhs, const APInt &result)>; |
268 | |
269 | static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, |
270 | const ConstantIntRanges &rhs, |
271 | DivisionFixupFn fixup) { |
272 | const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), |
273 | &rhsMax = rhs.umax(); |
274 | |
275 | if (!rhsMin.isZero()) { |
276 | auto udiv = [&fixup](const APInt &a, |
277 | const APInt &b) -> std::optional<APInt> { |
278 | return fixup(a, b, a.udiv(RHS: b)); |
279 | }; |
280 | return minMaxBy(op: udiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax}, |
281 | /*isSigned=*/false); |
282 | } |
283 | // Otherwise, it's possible we might divide by 0. |
284 | return ConstantIntRanges::maxRange(bitwidth: rhsMin.getBitWidth()); |
285 | } |
286 | |
287 | ConstantIntRanges |
288 | mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) { |
289 | return inferDivURange(lhs: argRanges[0], rhs: argRanges[1], |
290 | fixup: [](const APInt &lhs, const APInt &rhs, |
291 | const APInt &result) { return result; }); |
292 | } |
293 | |
294 | ConstantIntRanges |
295 | mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) { |
296 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
297 | |
298 | DivisionFixupFn ceilDivUIFix = |
299 | [](const APInt &lhs, const APInt &rhs, |
300 | const APInt &result) -> std::optional<APInt> { |
301 | if (!lhs.urem(RHS: rhs).isZero()) { |
302 | bool overflowed = false; |
303 | APInt corrected = |
304 | result.uadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed); |
305 | return overflowed ? std::optional<APInt>() : corrected; |
306 | } |
307 | return result; |
308 | }; |
309 | return inferDivURange(lhs, rhs, fixup: ceilDivUIFix); |
310 | } |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // DivS, CeilDivS, FloorDivS (Signed division) |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, |
317 | const ConstantIntRanges &rhs, |
318 | DivisionFixupFn fixup) { |
319 | const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), |
320 | &rhsMax = rhs.smax(); |
321 | bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); |
322 | |
323 | if (canDivide) { |
324 | auto sdiv = [&fixup](const APInt &a, |
325 | const APInt &b) -> std::optional<APInt> { |
326 | bool overflowed = false; |
327 | APInt result = a.sdiv_ov(RHS: b, Overflow&: overflowed); |
328 | return overflowed ? std::optional<APInt>() : fixup(a, b, result); |
329 | }; |
330 | return minMaxBy(op: sdiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax}, |
331 | /*isSigned=*/true); |
332 | } |
333 | return ConstantIntRanges::maxRange(bitwidth: rhsMin.getBitWidth()); |
334 | } |
335 | |
336 | ConstantIntRanges |
337 | mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) { |
338 | return inferDivSRange(lhs: argRanges[0], rhs: argRanges[1], |
339 | fixup: [](const APInt &lhs, const APInt &rhs, |
340 | const APInt &result) { return result; }); |
341 | } |
342 | |
343 | ConstantIntRanges |
344 | mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) { |
345 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
346 | |
347 | DivisionFixupFn ceilDivSIFix = |
348 | [](const APInt &lhs, const APInt &rhs, |
349 | const APInt &result) -> std::optional<APInt> { |
350 | if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { |
351 | bool overflowed = false; |
352 | APInt corrected = |
353 | result.sadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed); |
354 | return overflowed ? std::optional<APInt>() : corrected; |
355 | } |
356 | return result; |
357 | }; |
358 | return inferDivSRange(lhs, rhs, fixup: ceilDivSIFix); |
359 | } |
360 | |
361 | ConstantIntRanges |
362 | mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) { |
363 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
364 | |
365 | DivisionFixupFn floorDivSIFix = |
366 | [](const APInt &lhs, const APInt &rhs, |
367 | const APInt &result) -> std::optional<APInt> { |
368 | if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { |
369 | bool overflowed = false; |
370 | APInt corrected = |
371 | result.ssub_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed); |
372 | return overflowed ? std::optional<APInt>() : corrected; |
373 | } |
374 | return result; |
375 | }; |
376 | return inferDivSRange(lhs, rhs, fixup: floorDivSIFix); |
377 | } |
378 | |
379 | //===----------------------------------------------------------------------===// |
380 | // Signed remainder (RemS) |
381 | //===----------------------------------------------------------------------===// |
382 | |
383 | ConstantIntRanges |
384 | mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) { |
385 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
386 | const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), |
387 | &rhsMax = rhs.smax(); |
388 | |
389 | unsigned width = rhsMax.getBitWidth(); |
390 | APInt smin = APInt::getSignedMinValue(numBits: width); |
391 | APInt smax = APInt::getSignedMaxValue(numBits: width); |
392 | // No bounds if zero could be a divisor. |
393 | bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); |
394 | if (canBound) { |
395 | APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); |
396 | bool canNegativeDividend = lhsMin.isNegative(); |
397 | bool canPositiveDividend = lhsMax.isStrictlyPositive(); |
398 | APInt zero = APInt::getZero(numBits: maxDivisor.getBitWidth()); |
399 | APInt maxPositiveResult = maxDivisor - 1; |
400 | APInt minNegativeResult = -maxPositiveResult; |
401 | smin = canNegativeDividend ? minNegativeResult : zero; |
402 | smax = canPositiveDividend ? maxPositiveResult : zero; |
403 | // Special case: sweeping out a contiguous range in N/[modulus]. |
404 | if (rhsMin == rhsMax) { |
405 | if ((lhsMax - lhsMin).ult(RHS: maxDivisor)) { |
406 | APInt minRem = lhsMin.srem(RHS: maxDivisor); |
407 | APInt maxRem = lhsMax.srem(RHS: maxDivisor); |
408 | if (minRem.sle(RHS: maxRem)) { |
409 | smin = minRem; |
410 | smax = maxRem; |
411 | } |
412 | } |
413 | } |
414 | } |
415 | return ConstantIntRanges::fromSigned(smin, smax); |
416 | } |
417 | |
418 | //===----------------------------------------------------------------------===// |
419 | // Unsigned remainder (RemU) |
420 | //===----------------------------------------------------------------------===// |
421 | |
422 | ConstantIntRanges |
423 | mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) { |
424 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
425 | const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); |
426 | |
427 | unsigned width = rhsMin.getBitWidth(); |
428 | APInt umin = APInt::getZero(numBits: width); |
429 | APInt umax = APInt::getMaxValue(numBits: width); |
430 | |
431 | if (!rhsMin.isZero()) { |
432 | umax = rhsMax - 1; |
433 | // Special case: sweeping out a contiguous range in N/[modulus] |
434 | if (rhsMin == rhsMax) { |
435 | const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); |
436 | if ((lhsMax - lhsMin).ult(RHS: rhsMax)) { |
437 | APInt minRem = lhsMin.urem(RHS: rhsMax); |
438 | APInt maxRem = lhsMax.urem(RHS: rhsMax); |
439 | if (minRem.ule(RHS: maxRem)) { |
440 | umin = minRem; |
441 | umax = maxRem; |
442 | } |
443 | } |
444 | } |
445 | } |
446 | return ConstantIntRanges::fromUnsigned(umin, umax); |
447 | } |
448 | |
449 | //===----------------------------------------------------------------------===// |
450 | // Max and min (MaxS, MaxU, MinS, MinU) |
451 | //===----------------------------------------------------------------------===// |
452 | |
453 | ConstantIntRanges |
454 | mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) { |
455 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
456 | |
457 | const APInt &smin = lhs.smin().sgt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin(); |
458 | const APInt &smax = lhs.smax().sgt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax(); |
459 | return ConstantIntRanges::fromSigned(smin, smax); |
460 | } |
461 | |
462 | ConstantIntRanges |
463 | mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) { |
464 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
465 | |
466 | const APInt &umin = lhs.umin().ugt(RHS: rhs.umin()) ? lhs.umin() : rhs.umin(); |
467 | const APInt &umax = lhs.umax().ugt(RHS: rhs.umax()) ? lhs.umax() : rhs.umax(); |
468 | return ConstantIntRanges::fromUnsigned(umin, umax); |
469 | } |
470 | |
471 | ConstantIntRanges |
472 | mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) { |
473 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
474 | |
475 | const APInt &smin = lhs.smin().slt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin(); |
476 | const APInt &smax = lhs.smax().slt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax(); |
477 | return ConstantIntRanges::fromSigned(smin, smax); |
478 | } |
479 | |
480 | ConstantIntRanges |
481 | mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) { |
482 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
483 | |
484 | const APInt &umin = lhs.umin().ult(RHS: rhs.umin()) ? lhs.umin() : rhs.umin(); |
485 | const APInt &umax = lhs.umax().ult(RHS: rhs.umax()) ? lhs.umax() : rhs.umax(); |
486 | return ConstantIntRanges::fromUnsigned(umin, umax); |
487 | } |
488 | |
489 | //===----------------------------------------------------------------------===// |
490 | // Bitwise operators (And, Or, Xor) |
491 | //===----------------------------------------------------------------------===// |
492 | |
493 | /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, |
494 | /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits |
495 | /// that both bonuds have in common. This gives us a consertive approximation |
496 | /// for what values can be passed to bitwise operations. |
497 | static std::tuple<APInt, APInt> |
498 | widenBitwiseBounds(const ConstantIntRanges &bound) { |
499 | APInt leftVal = bound.umin(), rightVal = bound.umax(); |
500 | unsigned bitwidth = leftVal.getBitWidth(); |
501 | unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); |
502 | leftVal.clearLowBits(loBits: differingBits); |
503 | rightVal.setLowBits(differingBits); |
504 | return std::make_tuple(args: std::move(leftVal), args: std::move(rightVal)); |
505 | } |
506 | |
507 | ConstantIntRanges |
508 | mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) { |
509 | auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]); |
510 | auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]); |
511 | auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> { |
512 | return a & b; |
513 | }; |
514 | return minMaxBy(op: andi, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes}, |
515 | /*isSigned=*/false); |
516 | } |
517 | |
518 | ConstantIntRanges |
519 | mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) { |
520 | auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]); |
521 | auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]); |
522 | auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { |
523 | return a | b; |
524 | }; |
525 | return minMaxBy(op: ori, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes}, |
526 | /*isSigned=*/false); |
527 | } |
528 | |
529 | ConstantIntRanges |
530 | mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) { |
531 | auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]); |
532 | auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]); |
533 | auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { |
534 | return a ^ b; |
535 | }; |
536 | return minMaxBy(op: xori, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes}, |
537 | /*isSigned=*/false); |
538 | } |
539 | |
540 | //===----------------------------------------------------------------------===// |
541 | // Shifts (Shl, ShrS, ShrU) |
542 | //===----------------------------------------------------------------------===// |
543 | |
544 | ConstantIntRanges |
545 | mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) { |
546 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
547 | ConstArithFn shl = [](const APInt &l, |
548 | const APInt &r) -> std::optional<APInt> { |
549 | return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.shl(ShiftAmt: r); |
550 | }; |
551 | ConstantIntRanges urange = |
552 | minMaxBy(op: shl, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()}, |
553 | /*isSigned=*/false); |
554 | ConstantIntRanges srange = |
555 | minMaxBy(op: shl, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.umin(), rhs.umax()}, |
556 | /*isSigned=*/true); |
557 | return urange.intersection(other: srange); |
558 | } |
559 | |
560 | ConstantIntRanges |
561 | mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) { |
562 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
563 | |
564 | ConstArithFn ashr = [](const APInt &l, |
565 | const APInt &r) -> std::optional<APInt> { |
566 | return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.ashr(ShiftAmt: r); |
567 | }; |
568 | |
569 | return minMaxBy(op: ashr, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.umin(), rhs.umax()}, |
570 | /*isSigned=*/true); |
571 | } |
572 | |
573 | ConstantIntRanges |
574 | mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) { |
575 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
576 | |
577 | ConstArithFn lshr = [](const APInt &l, |
578 | const APInt &r) -> std::optional<APInt> { |
579 | return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.lshr(ShiftAmt: r); |
580 | }; |
581 | return minMaxBy(op: lshr, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()}, |
582 | /*isSigned=*/false); |
583 | } |
584 | |
585 | //===----------------------------------------------------------------------===// |
586 | // Comparisons (Cmp) |
587 | //===----------------------------------------------------------------------===// |
588 | |
589 | static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { |
590 | switch (pred) { |
591 | case intrange::CmpPredicate::eq: |
592 | return intrange::CmpPredicate::ne; |
593 | case intrange::CmpPredicate::ne: |
594 | return intrange::CmpPredicate::eq; |
595 | case intrange::CmpPredicate::slt: |
596 | return intrange::CmpPredicate::sge; |
597 | case intrange::CmpPredicate::sle: |
598 | return intrange::CmpPredicate::sgt; |
599 | case intrange::CmpPredicate::sgt: |
600 | return intrange::CmpPredicate::sle; |
601 | case intrange::CmpPredicate::sge: |
602 | return intrange::CmpPredicate::slt; |
603 | case intrange::CmpPredicate::ult: |
604 | return intrange::CmpPredicate::uge; |
605 | case intrange::CmpPredicate::ule: |
606 | return intrange::CmpPredicate::ugt; |
607 | case intrange::CmpPredicate::ugt: |
608 | return intrange::CmpPredicate::ule; |
609 | case intrange::CmpPredicate::uge: |
610 | return intrange::CmpPredicate::ult; |
611 | } |
612 | llvm_unreachable("unknown cmp predicate value" ); |
613 | } |
614 | |
615 | static bool isStaticallyTrue(intrange::CmpPredicate pred, |
616 | const ConstantIntRanges &lhs, |
617 | const ConstantIntRanges &rhs) { |
618 | switch (pred) { |
619 | case intrange::CmpPredicate::sle: |
620 | return lhs.smax().sle(RHS: rhs.smin()); |
621 | case intrange::CmpPredicate::slt: |
622 | return lhs.smax().slt(RHS: rhs.smin()); |
623 | case intrange::CmpPredicate::ule: |
624 | return lhs.umax().ule(RHS: rhs.umin()); |
625 | case intrange::CmpPredicate::ult: |
626 | return lhs.umax().ult(RHS: rhs.umin()); |
627 | case intrange::CmpPredicate::sge: |
628 | return lhs.smin().sge(RHS: rhs.smax()); |
629 | case intrange::CmpPredicate::sgt: |
630 | return lhs.smin().sgt(RHS: rhs.smax()); |
631 | case intrange::CmpPredicate::uge: |
632 | return lhs.umin().uge(RHS: rhs.umax()); |
633 | case intrange::CmpPredicate::ugt: |
634 | return lhs.umin().ugt(RHS: rhs.umax()); |
635 | case intrange::CmpPredicate::eq: { |
636 | std::optional<APInt> lhsConst = lhs.getConstantValue(); |
637 | std::optional<APInt> rhsConst = rhs.getConstantValue(); |
638 | return lhsConst && rhsConst && lhsConst == rhsConst; |
639 | } |
640 | case intrange::CmpPredicate::ne: { |
641 | // While equality requires that there is an interpration of the preceeding |
642 | // computations that produces equal constants, whether that be signed or |
643 | // unsigned, statically determining inequality requires that neither |
644 | // interpretation produce potentially overlapping ranges. |
645 | bool sne = isStaticallyTrue(pred: intrange::CmpPredicate::slt, lhs, rhs) || |
646 | isStaticallyTrue(pred: intrange::CmpPredicate::sgt, lhs, rhs); |
647 | bool une = isStaticallyTrue(pred: intrange::CmpPredicate::ult, lhs, rhs) || |
648 | isStaticallyTrue(pred: intrange::CmpPredicate::ugt, lhs, rhs); |
649 | return sne && une; |
650 | } |
651 | } |
652 | return false; |
653 | } |
654 | |
655 | std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred, |
656 | const ConstantIntRanges &lhs, |
657 | const ConstantIntRanges &rhs) { |
658 | if (isStaticallyTrue(pred, lhs, rhs)) |
659 | return true; |
660 | if (isStaticallyTrue(pred: invertPredicate(pred), lhs, rhs)) |
661 | return false; |
662 | return std::nullopt; |
663 | } |
664 | |