1//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains implementations of range inference for operations that are
10// common to both the `arith` and `index` dialects to facilitate reuse.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
15
16#include "mlir/Interfaces/InferIntRangeInterface.h"
17
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/STLExtras.h"
20
21#include "llvm/Support/Debug.h"
22
23#include <iterator>
24#include <optional>
25
26using namespace mlir;
27
28#define DEBUG_TYPE "int-range-analysis"
29
30//===----------------------------------------------------------------------===//
31// General utilities
32//===----------------------------------------------------------------------===//
33
34/// Function that evaluates the result of doing something on arithmetic
35/// constants and returns std::nullopt on overflow.
36using ConstArithFn =
37 function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
38
39/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
40/// If either computation overflows, make the result unbounded.
41static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
42 const APInt &minRight,
43 const APInt &maxLeft,
44 const APInt &maxRight, bool isSigned) {
45 std::optional<APInt> maybeMin = op(minLeft, minRight);
46 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
47 if (maybeMin && maybeMax)
48 return ConstantIntRanges::range(min: *maybeMin, max: *maybeMax, isSigned);
49 return ConstantIntRanges::maxRange(bitwidth: minLeft.getBitWidth());
50}
51
52/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
53/// ignoring unbounded values. Returns the maximal range if `op` overflows.
54static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
55 ArrayRef<APInt> rhs, bool isSigned) {
56 unsigned width = lhs[0].getBitWidth();
57 APInt min =
58 isSigned ? APInt::getSignedMaxValue(numBits: width) : APInt::getMaxValue(numBits: width);
59 APInt max =
60 isSigned ? APInt::getSignedMinValue(numBits: width) : APInt::getZero(numBits: width);
61 for (const APInt &left : lhs) {
62 for (const APInt &right : rhs) {
63 std::optional<APInt> maybeThisResult = op(left, right);
64 if (!maybeThisResult)
65 return ConstantIntRanges::maxRange(bitwidth: width);
66 APInt result = std::move(*maybeThisResult);
67 min = (isSigned ? result.slt(RHS: min) : result.ult(RHS: min)) ? result : min;
68 max = (isSigned ? result.sgt(RHS: max) : result.ugt(RHS: max)) ? result : max;
69 }
70 }
71 return ConstantIntRanges::range(min, max, isSigned);
72}
73
74//===----------------------------------------------------------------------===//
75// Ext, trunc, index op handling
76//===----------------------------------------------------------------------===//
77
78ConstantIntRanges
79mlir::intrange::inferIndexOp(InferRangeFn inferFn,
80 ArrayRef<ConstantIntRanges> argRanges,
81 intrange::CmpMode mode) {
82 ConstantIntRanges sixtyFour = inferFn(argRanges);
83 SmallVector<ConstantIntRanges, 2> truncated;
84 llvm::transform(Range&: argRanges, d_first: std::back_inserter(x&: truncated),
85 F: [](const ConstantIntRanges &range) {
86 return truncRange(range, /*destWidth=*/indexMinWidth);
87 });
88 ConstantIntRanges thirtyTwo = inferFn(truncated);
89 ConstantIntRanges thirtyTwoAsSixtyFour =
90 extRange(range: thirtyTwo, /*destWidth=*/indexMaxWidth);
91 ConstantIntRanges sixtyFourAsThirtyTwo =
92 truncRange(range: sixtyFour, /*destWidth=*/indexMinWidth);
93
94 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
95 << " 32-bit = " << thirtyTwo << "\n");
96 bool truncEqual = false;
97 switch (mode) {
98 case intrange::CmpMode::Both:
99 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
100 break;
101 case intrange::CmpMode::Signed:
102 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
103 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
104 break;
105 case intrange::CmpMode::Unsigned:
106 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
107 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
108 break;
109 }
110 if (truncEqual)
111 // Returing the 64-bit result preserves more information.
112 return sixtyFour;
113 ConstantIntRanges merged = sixtyFour.rangeUnion(other: thirtyTwoAsSixtyFour);
114 return merged;
115}
116
117ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
118 unsigned int destWidth) {
119 APInt umin = range.umin().zext(width: destWidth);
120 APInt umax = range.umax().zext(width: destWidth);
121 APInt smin = range.smin().sext(width: destWidth);
122 APInt smax = range.smax().sext(width: destWidth);
123 return {umin, umax, smin, smax};
124}
125
126ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
127 unsigned destWidth) {
128 APInt umin = range.umin().zext(width: destWidth);
129 APInt umax = range.umax().zext(width: destWidth);
130 return ConstantIntRanges::fromUnsigned(umin, umax);
131}
132
133ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
134 unsigned destWidth) {
135 APInt smin = range.smin().sext(width: destWidth);
136 APInt smax = range.smax().sext(width: destWidth);
137 return ConstantIntRanges::fromSigned(smin, smax);
138}
139
140ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
141 unsigned int destWidth) {
142 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
143 // the range of the resulting value is not contiguous ind includes 0.
144 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
145 // but you can't truncate [255, 257] similarly.
146 bool hasUnsignedRollover =
147 range.umin().lshr(shiftAmt: destWidth) != range.umax().lshr(shiftAmt: destWidth);
148 APInt umin = hasUnsignedRollover ? APInt::getZero(numBits: destWidth)
149 : range.umin().trunc(width: destWidth);
150 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(numBits: destWidth)
151 : range.umax().trunc(width: destWidth);
152
153 // Signed post-truncation rollover will not occur when either:
154 // - The high parts of the min and max, plus the sign bit, are the same
155 // - The high halves + sign bit of the min and max are either all 1s or all 0s
156 // and you won't create a [positive, negative] range by truncating.
157 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
158 // but not [255, 257]_i16 to a range of i8s. You can also truncate
159 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
160 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
161 // will truncate to 0x7e, which is greater than 0
162 APInt sminHighPart = range.smin().ashr(ShiftAmt: destWidth - 1);
163 APInt smaxHighPart = range.smax().ashr(ShiftAmt: destWidth - 1);
164 bool hasSignedOverflow =
165 (sminHighPart != smaxHighPart) &&
166 !(sminHighPart.isAllOnes() &&
167 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
168 !(sminHighPart.isZero() && smaxHighPart.isZero());
169 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(numBits: destWidth)
170 : range.smin().trunc(width: destWidth);
171 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(numBits: destWidth)
172 : range.smax().trunc(width: destWidth);
173 return {umin, umax, smin, smax};
174}
175
176//===----------------------------------------------------------------------===//
177// Addition
178//===----------------------------------------------------------------------===//
179
180ConstantIntRanges
181mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
182 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
183 ConstArithFn uadd = [](const APInt &a,
184 const APInt &b) -> std::optional<APInt> {
185 bool overflowed = false;
186 APInt result = a.uadd_ov(RHS: b, Overflow&: overflowed);
187 return overflowed ? std::optional<APInt>() : result;
188 };
189 ConstArithFn sadd = [](const APInt &a,
190 const APInt &b) -> std::optional<APInt> {
191 bool overflowed = false;
192 APInt result = a.sadd_ov(RHS: b, Overflow&: overflowed);
193 return overflowed ? std::optional<APInt>() : result;
194 };
195
196 ConstantIntRanges urange = computeBoundsBy(
197 op: uadd, minLeft: lhs.umin(), minRight: rhs.umin(), maxLeft: lhs.umax(), maxRight: rhs.umax(), /*isSigned=*/false);
198 ConstantIntRanges srange = computeBoundsBy(
199 op: sadd, minLeft: lhs.smin(), minRight: rhs.smin(), maxLeft: lhs.smax(), maxRight: rhs.smax(), /*isSigned=*/true);
200 return urange.intersection(other: srange);
201}
202
203//===----------------------------------------------------------------------===//
204// Subtraction
205//===----------------------------------------------------------------------===//
206
207ConstantIntRanges
208mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
209 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
210
211 ConstArithFn usub = [](const APInt &a,
212 const APInt &b) -> std::optional<APInt> {
213 bool overflowed = false;
214 APInt result = a.usub_ov(RHS: b, Overflow&: overflowed);
215 return overflowed ? std::optional<APInt>() : result;
216 };
217 ConstArithFn ssub = [](const APInt &a,
218 const APInt &b) -> std::optional<APInt> {
219 bool overflowed = false;
220 APInt result = a.ssub_ov(RHS: b, Overflow&: overflowed);
221 return overflowed ? std::optional<APInt>() : result;
222 };
223 ConstantIntRanges urange = computeBoundsBy(
224 op: usub, minLeft: lhs.umin(), minRight: rhs.umax(), maxLeft: lhs.umax(), maxRight: rhs.umin(), /*isSigned=*/false);
225 ConstantIntRanges srange = computeBoundsBy(
226 op: ssub, minLeft: lhs.smin(), minRight: rhs.smax(), maxLeft: lhs.smax(), maxRight: rhs.smin(), /*isSigned=*/true);
227 return urange.intersection(other: srange);
228}
229
230//===----------------------------------------------------------------------===//
231// Multiplication
232//===----------------------------------------------------------------------===//
233
234ConstantIntRanges
235mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
236 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
237
238 ConstArithFn umul = [](const APInt &a,
239 const APInt &b) -> std::optional<APInt> {
240 bool overflowed = false;
241 APInt result = a.umul_ov(RHS: b, Overflow&: overflowed);
242 return overflowed ? std::optional<APInt>() : result;
243 };
244 ConstArithFn smul = [](const APInt &a,
245 const APInt &b) -> std::optional<APInt> {
246 bool overflowed = false;
247 APInt result = a.smul_ov(RHS: b, Overflow&: overflowed);
248 return overflowed ? std::optional<APInt>() : result;
249 };
250
251 ConstantIntRanges urange =
252 minMaxBy(op: umul, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()},
253 /*isSigned=*/false);
254 ConstantIntRanges srange =
255 minMaxBy(op: smul, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.smin(), rhs.smax()},
256 /*isSigned=*/true);
257 return urange.intersection(other: srange);
258}
259
260//===----------------------------------------------------------------------===//
261// DivU, CeilDivU (Unsigned division)
262//===----------------------------------------------------------------------===//
263
264/// Fix up division results (ex. for ceiling and floor), returning an APInt
265/// if there has been no overflow
266using DivisionFixupFn = function_ref<std::optional<APInt>(
267 const APInt &lhs, const APInt &rhs, const APInt &result)>;
268
269static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
270 const ConstantIntRanges &rhs,
271 DivisionFixupFn fixup) {
272 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
273 &rhsMax = rhs.umax();
274
275 if (!rhsMin.isZero()) {
276 auto udiv = [&fixup](const APInt &a,
277 const APInt &b) -> std::optional<APInt> {
278 return fixup(a, b, a.udiv(RHS: b));
279 };
280 return minMaxBy(op: udiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax},
281 /*isSigned=*/false);
282 }
283 // Otherwise, it's possible we might divide by 0.
284 return ConstantIntRanges::maxRange(bitwidth: rhsMin.getBitWidth());
285}
286
287ConstantIntRanges
288mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
289 return inferDivURange(lhs: argRanges[0], rhs: argRanges[1],
290 fixup: [](const APInt &lhs, const APInt &rhs,
291 const APInt &result) { return result; });
292}
293
294ConstantIntRanges
295mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
296 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
297
298 DivisionFixupFn ceilDivUIFix =
299 [](const APInt &lhs, const APInt &rhs,
300 const APInt &result) -> std::optional<APInt> {
301 if (!lhs.urem(RHS: rhs).isZero()) {
302 bool overflowed = false;
303 APInt corrected =
304 result.uadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
305 return overflowed ? std::optional<APInt>() : corrected;
306 }
307 return result;
308 };
309 return inferDivURange(lhs, rhs, fixup: ceilDivUIFix);
310}
311
312//===----------------------------------------------------------------------===//
313// DivS, CeilDivS, FloorDivS (Signed division)
314//===----------------------------------------------------------------------===//
315
316static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
317 const ConstantIntRanges &rhs,
318 DivisionFixupFn fixup) {
319 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
320 &rhsMax = rhs.smax();
321 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
322
323 if (canDivide) {
324 auto sdiv = [&fixup](const APInt &a,
325 const APInt &b) -> std::optional<APInt> {
326 bool overflowed = false;
327 APInt result = a.sdiv_ov(RHS: b, Overflow&: overflowed);
328 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
329 };
330 return minMaxBy(op: sdiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax},
331 /*isSigned=*/true);
332 }
333 return ConstantIntRanges::maxRange(bitwidth: rhsMin.getBitWidth());
334}
335
336ConstantIntRanges
337mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
338 return inferDivSRange(lhs: argRanges[0], rhs: argRanges[1],
339 fixup: [](const APInt &lhs, const APInt &rhs,
340 const APInt &result) { return result; });
341}
342
343ConstantIntRanges
344mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
345 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
346
347 DivisionFixupFn ceilDivSIFix =
348 [](const APInt &lhs, const APInt &rhs,
349 const APInt &result) -> std::optional<APInt> {
350 if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
351 bool overflowed = false;
352 APInt corrected =
353 result.sadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
354 return overflowed ? std::optional<APInt>() : corrected;
355 }
356 return result;
357 };
358 return inferDivSRange(lhs, rhs, fixup: ceilDivSIFix);
359}
360
361ConstantIntRanges
362mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
363 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
364
365 DivisionFixupFn floorDivSIFix =
366 [](const APInt &lhs, const APInt &rhs,
367 const APInt &result) -> std::optional<APInt> {
368 if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
369 bool overflowed = false;
370 APInt corrected =
371 result.ssub_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
372 return overflowed ? std::optional<APInt>() : corrected;
373 }
374 return result;
375 };
376 return inferDivSRange(lhs, rhs, fixup: floorDivSIFix);
377}
378
379//===----------------------------------------------------------------------===//
380// Signed remainder (RemS)
381//===----------------------------------------------------------------------===//
382
383ConstantIntRanges
384mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
385 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
386 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
387 &rhsMax = rhs.smax();
388
389 unsigned width = rhsMax.getBitWidth();
390 APInt smin = APInt::getSignedMinValue(numBits: width);
391 APInt smax = APInt::getSignedMaxValue(numBits: width);
392 // No bounds if zero could be a divisor.
393 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
394 if (canBound) {
395 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
396 bool canNegativeDividend = lhsMin.isNegative();
397 bool canPositiveDividend = lhsMax.isStrictlyPositive();
398 APInt zero = APInt::getZero(numBits: maxDivisor.getBitWidth());
399 APInt maxPositiveResult = maxDivisor - 1;
400 APInt minNegativeResult = -maxPositiveResult;
401 smin = canNegativeDividend ? minNegativeResult : zero;
402 smax = canPositiveDividend ? maxPositiveResult : zero;
403 // Special case: sweeping out a contiguous range in N/[modulus].
404 if (rhsMin == rhsMax) {
405 if ((lhsMax - lhsMin).ult(RHS: maxDivisor)) {
406 APInt minRem = lhsMin.srem(RHS: maxDivisor);
407 APInt maxRem = lhsMax.srem(RHS: maxDivisor);
408 if (minRem.sle(RHS: maxRem)) {
409 smin = minRem;
410 smax = maxRem;
411 }
412 }
413 }
414 }
415 return ConstantIntRanges::fromSigned(smin, smax);
416}
417
418//===----------------------------------------------------------------------===//
419// Unsigned remainder (RemU)
420//===----------------------------------------------------------------------===//
421
422ConstantIntRanges
423mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
424 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
425 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
426
427 unsigned width = rhsMin.getBitWidth();
428 APInt umin = APInt::getZero(numBits: width);
429 APInt umax = APInt::getMaxValue(numBits: width);
430
431 if (!rhsMin.isZero()) {
432 umax = rhsMax - 1;
433 // Special case: sweeping out a contiguous range in N/[modulus]
434 if (rhsMin == rhsMax) {
435 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
436 if ((lhsMax - lhsMin).ult(RHS: rhsMax)) {
437 APInt minRem = lhsMin.urem(RHS: rhsMax);
438 APInt maxRem = lhsMax.urem(RHS: rhsMax);
439 if (minRem.ule(RHS: maxRem)) {
440 umin = minRem;
441 umax = maxRem;
442 }
443 }
444 }
445 }
446 return ConstantIntRanges::fromUnsigned(umin, umax);
447}
448
449//===----------------------------------------------------------------------===//
450// Max and min (MaxS, MaxU, MinS, MinU)
451//===----------------------------------------------------------------------===//
452
453ConstantIntRanges
454mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
455 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
456
457 const APInt &smin = lhs.smin().sgt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin();
458 const APInt &smax = lhs.smax().sgt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax();
459 return ConstantIntRanges::fromSigned(smin, smax);
460}
461
462ConstantIntRanges
463mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
464 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
465
466 const APInt &umin = lhs.umin().ugt(RHS: rhs.umin()) ? lhs.umin() : rhs.umin();
467 const APInt &umax = lhs.umax().ugt(RHS: rhs.umax()) ? lhs.umax() : rhs.umax();
468 return ConstantIntRanges::fromUnsigned(umin, umax);
469}
470
471ConstantIntRanges
472mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
473 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
474
475 const APInt &smin = lhs.smin().slt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin();
476 const APInt &smax = lhs.smax().slt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax();
477 return ConstantIntRanges::fromSigned(smin, smax);
478}
479
480ConstantIntRanges
481mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
482 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
483
484 const APInt &umin = lhs.umin().ult(RHS: rhs.umin()) ? lhs.umin() : rhs.umin();
485 const APInt &umax = lhs.umax().ult(RHS: rhs.umax()) ? lhs.umax() : rhs.umax();
486 return ConstantIntRanges::fromUnsigned(umin, umax);
487}
488
489//===----------------------------------------------------------------------===//
490// Bitwise operators (And, Or, Xor)
491//===----------------------------------------------------------------------===//
492
493/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
494/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
495/// that both bonuds have in common. This gives us a consertive approximation
496/// for what values can be passed to bitwise operations.
497static std::tuple<APInt, APInt>
498widenBitwiseBounds(const ConstantIntRanges &bound) {
499 APInt leftVal = bound.umin(), rightVal = bound.umax();
500 unsigned bitwidth = leftVal.getBitWidth();
501 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
502 leftVal.clearLowBits(loBits: differingBits);
503 rightVal.setLowBits(differingBits);
504 return std::make_tuple(args: std::move(leftVal), args: std::move(rightVal));
505}
506
507ConstantIntRanges
508mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
509 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]);
510 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]);
511 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
512 return a & b;
513 };
514 return minMaxBy(op: andi, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes},
515 /*isSigned=*/false);
516}
517
518ConstantIntRanges
519mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
520 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]);
521 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]);
522 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
523 return a | b;
524 };
525 return minMaxBy(op: ori, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes},
526 /*isSigned=*/false);
527}
528
529ConstantIntRanges
530mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
531 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]);
532 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]);
533 auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
534 return a ^ b;
535 };
536 return minMaxBy(op: xori, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes},
537 /*isSigned=*/false);
538}
539
540//===----------------------------------------------------------------------===//
541// Shifts (Shl, ShrS, ShrU)
542//===----------------------------------------------------------------------===//
543
544ConstantIntRanges
545mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
546 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
547 ConstArithFn shl = [](const APInt &l,
548 const APInt &r) -> std::optional<APInt> {
549 return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.shl(ShiftAmt: r);
550 };
551 ConstantIntRanges urange =
552 minMaxBy(op: shl, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()},
553 /*isSigned=*/false);
554 ConstantIntRanges srange =
555 minMaxBy(op: shl, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.umin(), rhs.umax()},
556 /*isSigned=*/true);
557 return urange.intersection(other: srange);
558}
559
560ConstantIntRanges
561mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
562 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
563
564 ConstArithFn ashr = [](const APInt &l,
565 const APInt &r) -> std::optional<APInt> {
566 return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.ashr(ShiftAmt: r);
567 };
568
569 return minMaxBy(op: ashr, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.umin(), rhs.umax()},
570 /*isSigned=*/true);
571}
572
573ConstantIntRanges
574mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
575 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
576
577 ConstArithFn lshr = [](const APInt &l,
578 const APInt &r) -> std::optional<APInt> {
579 return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.lshr(ShiftAmt: r);
580 };
581 return minMaxBy(op: lshr, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()},
582 /*isSigned=*/false);
583}
584
585//===----------------------------------------------------------------------===//
586// Comparisons (Cmp)
587//===----------------------------------------------------------------------===//
588
589static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
590 switch (pred) {
591 case intrange::CmpPredicate::eq:
592 return intrange::CmpPredicate::ne;
593 case intrange::CmpPredicate::ne:
594 return intrange::CmpPredicate::eq;
595 case intrange::CmpPredicate::slt:
596 return intrange::CmpPredicate::sge;
597 case intrange::CmpPredicate::sle:
598 return intrange::CmpPredicate::sgt;
599 case intrange::CmpPredicate::sgt:
600 return intrange::CmpPredicate::sle;
601 case intrange::CmpPredicate::sge:
602 return intrange::CmpPredicate::slt;
603 case intrange::CmpPredicate::ult:
604 return intrange::CmpPredicate::uge;
605 case intrange::CmpPredicate::ule:
606 return intrange::CmpPredicate::ugt;
607 case intrange::CmpPredicate::ugt:
608 return intrange::CmpPredicate::ule;
609 case intrange::CmpPredicate::uge:
610 return intrange::CmpPredicate::ult;
611 }
612 llvm_unreachable("unknown cmp predicate value");
613}
614
615static bool isStaticallyTrue(intrange::CmpPredicate pred,
616 const ConstantIntRanges &lhs,
617 const ConstantIntRanges &rhs) {
618 switch (pred) {
619 case intrange::CmpPredicate::sle:
620 return lhs.smax().sle(RHS: rhs.smin());
621 case intrange::CmpPredicate::slt:
622 return lhs.smax().slt(RHS: rhs.smin());
623 case intrange::CmpPredicate::ule:
624 return lhs.umax().ule(RHS: rhs.umin());
625 case intrange::CmpPredicate::ult:
626 return lhs.umax().ult(RHS: rhs.umin());
627 case intrange::CmpPredicate::sge:
628 return lhs.smin().sge(RHS: rhs.smax());
629 case intrange::CmpPredicate::sgt:
630 return lhs.smin().sgt(RHS: rhs.smax());
631 case intrange::CmpPredicate::uge:
632 return lhs.umin().uge(RHS: rhs.umax());
633 case intrange::CmpPredicate::ugt:
634 return lhs.umin().ugt(RHS: rhs.umax());
635 case intrange::CmpPredicate::eq: {
636 std::optional<APInt> lhsConst = lhs.getConstantValue();
637 std::optional<APInt> rhsConst = rhs.getConstantValue();
638 return lhsConst && rhsConst && lhsConst == rhsConst;
639 }
640 case intrange::CmpPredicate::ne: {
641 // While equality requires that there is an interpration of the preceeding
642 // computations that produces equal constants, whether that be signed or
643 // unsigned, statically determining inequality requires that neither
644 // interpretation produce potentially overlapping ranges.
645 bool sne = isStaticallyTrue(pred: intrange::CmpPredicate::slt, lhs, rhs) ||
646 isStaticallyTrue(pred: intrange::CmpPredicate::sgt, lhs, rhs);
647 bool une = isStaticallyTrue(pred: intrange::CmpPredicate::ult, lhs, rhs) ||
648 isStaticallyTrue(pred: intrange::CmpPredicate::ugt, lhs, rhs);
649 return sne && une;
650 }
651 }
652 return false;
653}
654
655std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
656 const ConstantIntRanges &lhs,
657 const ConstantIntRanges &rhs) {
658 if (isStaticallyTrue(pred, lhs, rhs))
659 return true;
660 if (isStaticallyTrue(pred: invertPredicate(pred), lhs, rhs))
661 return false;
662 return std::nullopt;
663}
664

source code of mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp