1//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains implementations of range inference for operations that are
10// common to both the `arith` and `index` dialects to facilitate reuse.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
15
16#include "mlir/Interfaces/InferIntRangeInterface.h"
17#include "mlir/Interfaces/ShapedOpInterfaces.h"
18
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/STLExtras.h"
21
22#include "llvm/Support/Debug.h"
23
24#include <iterator>
25#include <optional>
26
27using namespace mlir;
28
29#define DEBUG_TYPE "int-range-analysis"
30
31//===----------------------------------------------------------------------===//
32// General utilities
33//===----------------------------------------------------------------------===//
34
/// Callback that applies a binary arithmetic operation to two constants.
/// Returns the result, or std::nullopt if the operation overflowed; the range
/// helpers in this file respond to std::nullopt by returning the maximal
/// range.
using ConstArithFn =
    function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
/// Owning counterpart of ConstArithFn. The infer* functions store their
/// arithmetic lambdas in named locals of this type; a non-owning function_ref
/// initialized directly from a temporary lambda would dangle.
using ConstArithStdFn =
    std::function<std::optional<APInt>(const APInt &, const APInt &)>;
41
42/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
43/// If either computation overflows, make the result unbounded.
44static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
45 const APInt &minRight,
46 const APInt &maxLeft,
47 const APInt &maxRight, bool isSigned) {
48 std::optional<APInt> maybeMin = op(minLeft, minRight);
49 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50 if (maybeMin && maybeMax)
51 return ConstantIntRanges::range(min: *maybeMin, max: *maybeMax, isSigned);
52 return ConstantIntRanges::maxRange(bitwidth: minLeft.getBitWidth());
53}
54
55/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
56/// ignoring unbounded values. Returns the maximal range if `op` overflows.
57static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
58 ArrayRef<APInt> rhs, bool isSigned) {
59 unsigned width = lhs[0].getBitWidth();
60 APInt min =
61 isSigned ? APInt::getSignedMaxValue(numBits: width) : APInt::getMaxValue(numBits: width);
62 APInt max =
63 isSigned ? APInt::getSignedMinValue(numBits: width) : APInt::getZero(numBits: width);
64 for (const APInt &left : lhs) {
65 for (const APInt &right : rhs) {
66 std::optional<APInt> maybeThisResult = op(left, right);
67 if (!maybeThisResult)
68 return ConstantIntRanges::maxRange(bitwidth: width);
69 APInt result = std::move(*maybeThisResult);
70 min = (isSigned ? result.slt(RHS: min) : result.ult(RHS: min)) ? result : min;
71 max = (isSigned ? result.sgt(RHS: max) : result.ugt(RHS: max)) ? result : max;
72 }
73 }
74 return ConstantIntRanges::range(min, max, isSigned);
75}
76
77//===----------------------------------------------------------------------===//
78// Ext, trunc, index op handling
79//===----------------------------------------------------------------------===//
80
81ConstantIntRanges
82mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
83 ArrayRef<ConstantIntRanges> argRanges,
84 intrange::CmpMode mode) {
85 ConstantIntRanges sixtyFour = inferFn(argRanges);
86 SmallVector<ConstantIntRanges, 2> truncated;
87 llvm::transform(Range&: argRanges, d_first: std::back_inserter(x&: truncated),
88 F: [](const ConstantIntRanges &range) {
89 return truncRange(range, /*destWidth=*/indexMinWidth);
90 });
91 ConstantIntRanges thirtyTwo = inferFn(truncated);
92 ConstantIntRanges thirtyTwoAsSixtyFour =
93 extRange(range: thirtyTwo, /*destWidth=*/indexMaxWidth);
94 ConstantIntRanges sixtyFourAsThirtyTwo =
95 truncRange(range: sixtyFour, /*destWidth=*/indexMinWidth);
96
97 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
98 << " 32-bit = " << thirtyTwo << "\n");
99 bool truncEqual = false;
100 switch (mode) {
101 case intrange::CmpMode::Both:
102 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
103 break;
104 case intrange::CmpMode::Signed:
105 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
106 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
107 break;
108 case intrange::CmpMode::Unsigned:
109 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
110 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
111 break;
112 }
113 if (truncEqual)
114 // Returing the 64-bit result preserves more information.
115 return sixtyFour;
116 ConstantIntRanges merged = sixtyFour.rangeUnion(other: thirtyTwoAsSixtyFour);
117 return merged;
118}
119
120ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
121 unsigned int destWidth) {
122 APInt umin = range.umin().zext(width: destWidth);
123 APInt umax = range.umax().zext(width: destWidth);
124 APInt smin = range.smin().sext(width: destWidth);
125 APInt smax = range.smax().sext(width: destWidth);
126 return {umin, umax, smin, smax};
127}
128
129ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
130 unsigned destWidth) {
131 APInt umin = range.umin().zext(width: destWidth);
132 APInt umax = range.umax().zext(width: destWidth);
133 return ConstantIntRanges::fromUnsigned(umin, umax);
134}
135
136ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
137 unsigned destWidth) {
138 APInt smin = range.smin().sext(width: destWidth);
139 APInt smax = range.smax().sext(width: destWidth);
140 return ConstantIntRanges::fromSigned(smin, smax);
141}
142
143ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
144 unsigned int destWidth) {
145 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
146 // the range of the resulting value is not contiguous ind includes 0.
147 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
148 // but you can't truncate [255, 257] similarly.
149 bool hasUnsignedRollover =
150 range.umin().lshr(shiftAmt: destWidth) != range.umax().lshr(shiftAmt: destWidth);
151 APInt umin = hasUnsignedRollover ? APInt::getZero(numBits: destWidth)
152 : range.umin().trunc(width: destWidth);
153 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(numBits: destWidth)
154 : range.umax().trunc(width: destWidth);
155
156 // Signed post-truncation rollover will not occur when either:
157 // - The high parts of the min and max, plus the sign bit, are the same
158 // - The high halves + sign bit of the min and max are either all 1s or all 0s
159 // and you won't create a [positive, negative] range by truncating.
160 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
161 // but not [255, 257]_i16 to a range of i8s. You can also truncate
162 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
163 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
164 // will truncate to 0x7e, which is greater than 0
165 APInt sminHighPart = range.smin().ashr(ShiftAmt: destWidth - 1);
166 APInt smaxHighPart = range.smax().ashr(ShiftAmt: destWidth - 1);
167 bool hasSignedOverflow =
168 (sminHighPart != smaxHighPart) &&
169 !(sminHighPart.isAllOnes() &&
170 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171 !(sminHighPart.isZero() && smaxHighPart.isZero());
172 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(numBits: destWidth)
173 : range.smin().trunc(width: destWidth);
174 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(numBits: destWidth)
175 : range.smax().trunc(width: destWidth);
176 return {umin, umax, smin, smax};
177}
178
179//===----------------------------------------------------------------------===//
180// Addition
181//===----------------------------------------------------------------------===//
182
183ConstantIntRanges
184mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
185 OverflowFlags ovfFlags) {
186 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
187
188 ConstArithStdFn uadd = [=](const APInt &a,
189 const APInt &b) -> std::optional<APInt> {
190 bool overflowed = false;
191 APInt result = any(Val: ovfFlags & OverflowFlags::Nuw)
192 ? a.uadd_sat(RHS: b)
193 : a.uadd_ov(RHS: b, Overflow&: overflowed);
194 return overflowed ? std::optional<APInt>() : result;
195 };
196 ConstArithStdFn sadd = [=](const APInt &a,
197 const APInt &b) -> std::optional<APInt> {
198 bool overflowed = false;
199 APInt result = any(Val: ovfFlags & OverflowFlags::Nsw)
200 ? a.sadd_sat(RHS: b)
201 : a.sadd_ov(RHS: b, Overflow&: overflowed);
202 return overflowed ? std::optional<APInt>() : result;
203 };
204
205 ConstantIntRanges urange = computeBoundsBy(
206 op: uadd, minLeft: lhs.umin(), minRight: rhs.umin(), maxLeft: lhs.umax(), maxRight: rhs.umax(), /*isSigned=*/false);
207 ConstantIntRanges srange = computeBoundsBy(
208 op: sadd, minLeft: lhs.smin(), minRight: rhs.smin(), maxLeft: lhs.smax(), maxRight: rhs.smax(), /*isSigned=*/true);
209 return urange.intersection(other: srange);
210}
211
212//===----------------------------------------------------------------------===//
213// Subtraction
214//===----------------------------------------------------------------------===//
215
216ConstantIntRanges
217mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
218 OverflowFlags ovfFlags) {
219 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
220
221 ConstArithStdFn usub = [=](const APInt &a,
222 const APInt &b) -> std::optional<APInt> {
223 bool overflowed = false;
224 APInt result = any(Val: ovfFlags & OverflowFlags::Nuw)
225 ? a.usub_sat(RHS: b)
226 : a.usub_ov(RHS: b, Overflow&: overflowed);
227 return overflowed ? std::optional<APInt>() : result;
228 };
229 ConstArithStdFn ssub = [=](const APInt &a,
230 const APInt &b) -> std::optional<APInt> {
231 bool overflowed = false;
232 APInt result = any(Val: ovfFlags & OverflowFlags::Nsw)
233 ? a.ssub_sat(RHS: b)
234 : a.ssub_ov(RHS: b, Overflow&: overflowed);
235 return overflowed ? std::optional<APInt>() : result;
236 };
237 ConstantIntRanges urange = computeBoundsBy(
238 op: usub, minLeft: lhs.umin(), minRight: rhs.umax(), maxLeft: lhs.umax(), maxRight: rhs.umin(), /*isSigned=*/false);
239 ConstantIntRanges srange = computeBoundsBy(
240 op: ssub, minLeft: lhs.smin(), minRight: rhs.smax(), maxLeft: lhs.smax(), maxRight: rhs.smin(), /*isSigned=*/true);
241 return urange.intersection(other: srange);
242}
243
244//===----------------------------------------------------------------------===//
245// Multiplication
246//===----------------------------------------------------------------------===//
247
248ConstantIntRanges
249mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
250 OverflowFlags ovfFlags) {
251 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
252
253 ConstArithStdFn umul = [=](const APInt &a,
254 const APInt &b) -> std::optional<APInt> {
255 bool overflowed = false;
256 APInt result = any(Val: ovfFlags & OverflowFlags::Nuw)
257 ? a.umul_sat(RHS: b)
258 : a.umul_ov(RHS: b, Overflow&: overflowed);
259 return overflowed ? std::optional<APInt>() : result;
260 };
261 ConstArithStdFn smul = [=](const APInt &a,
262 const APInt &b) -> std::optional<APInt> {
263 bool overflowed = false;
264 APInt result = any(Val: ovfFlags & OverflowFlags::Nsw)
265 ? a.smul_sat(RHS: b)
266 : a.smul_ov(RHS: b, Overflow&: overflowed);
267 return overflowed ? std::optional<APInt>() : result;
268 };
269
270 ConstantIntRanges urange =
271 minMaxBy(op: umul, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()},
272 /*isSigned=*/false);
273 ConstantIntRanges srange =
274 minMaxBy(op: smul, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.smin(), rhs.smax()},
275 /*isSigned=*/true);
276 return urange.intersection(other: srange);
277}
278
279//===----------------------------------------------------------------------===//
280// DivU, CeilDivU (Unsigned division)
281//===----------------------------------------------------------------------===//
282
/// Callback used by the division helpers to post-process the truncated
/// quotient `result` of `lhs / rhs`, e.g. rounding it up for ceiling division
/// or down for floor division. Returns the adjusted value, or std::nullopt if
/// the adjustment overflowed.
using DivisionFixupFn = function_ref<std::optional<APInt>(
    const APInt &lhs, const APInt &rhs, const APInt &result)>;
287
288static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
289 const ConstantIntRanges &rhs,
290 DivisionFixupFn fixup) {
291 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
292 &rhsMax = rhs.umax();
293
294 if (!rhsMin.isZero()) {
295 auto udiv = [&fixup](const APInt &a,
296 const APInt &b) -> std::optional<APInt> {
297 return fixup(a, b, a.udiv(RHS: b));
298 };
299 return minMaxBy(op: udiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax},
300 /*isSigned=*/false);
301 }
302
303 APInt umin = APInt::getZero(numBits: rhsMin.getBitWidth());
304 if (lhsMin.uge(RHS: rhsMax) && !rhsMax.isZero())
305 umin = lhsMin.udiv(RHS: rhsMax);
306
307 // X u/ Y u<= X.
308 APInt umax = lhsMax;
309 return ConstantIntRanges::fromUnsigned(umin, umax);
310}
311
312ConstantIntRanges
313mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
314 return inferDivURange(lhs: argRanges[0], rhs: argRanges[1],
315 fixup: [](const APInt &lhs, const APInt &rhs,
316 const APInt &result) { return result; });
317}
318
319ConstantIntRanges
320mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
321 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
322
323 auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
324 const APInt &result) -> std::optional<APInt> {
325 if (!lhs.urem(RHS: rhs).isZero()) {
326 bool overflowed = false;
327 APInt corrected =
328 result.uadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
329 return overflowed ? std::optional<APInt>() : corrected;
330 }
331 return result;
332 };
333 return inferDivURange(lhs, rhs, fixup: ceilDivUIFix);
334}
335
336//===----------------------------------------------------------------------===//
337// DivS, CeilDivS, FloorDivS (Signed division)
338//===----------------------------------------------------------------------===//
339
340static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
341 const ConstantIntRanges &rhs,
342 DivisionFixupFn fixup) {
343 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
344 &rhsMax = rhs.smax();
345 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
346
347 if (canDivide) {
348 auto sdiv = [&fixup](const APInt &a,
349 const APInt &b) -> std::optional<APInt> {
350 bool overflowed = false;
351 APInt result = a.sdiv_ov(RHS: b, Overflow&: overflowed);
352 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
353 };
354 return minMaxBy(op: sdiv, lhs: {lhsMin, lhsMax}, rhs: {rhsMin, rhsMax},
355 /*isSigned=*/true);
356 }
357 return ConstantIntRanges::maxRange(bitwidth: rhsMin.getBitWidth());
358}
359
360ConstantIntRanges
361mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
362 return inferDivSRange(lhs: argRanges[0], rhs: argRanges[1],
363 fixup: [](const APInt &lhs, const APInt &rhs,
364 const APInt &result) { return result; });
365}
366
367ConstantIntRanges
368mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
369 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
370
371 auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
372 const APInt &result) -> std::optional<APInt> {
373 if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
374 bool overflowed = false;
375 APInt corrected =
376 result.sadd_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
377 return overflowed ? std::optional<APInt>() : corrected;
378 }
379 // Special case where the usual implementation of ceilDiv causes
380 // INT_MIN / [positive number] to be positive. This doesn't match the
381 // definition of signed ceiling division mathematically, but it prevents
382 // inconsistent constant-folding results. This arises because (-int_min) is
383 // still negative, so -(-int_min / b) is -(int_min / b), which is
384 // positive See #115293.
385 if (lhs.isMinSignedValue() && rhs.sgt(RHS: 1)) {
386 return -result;
387 }
388 return result;
389 };
390 ConstantIntRanges result = inferDivSRange(lhs, rhs, fixup: ceilDivSIFix);
391 if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(RHS: lhs.smin())) {
392 // If lhs range includes INT_MIN and lhs is not a single value, we can
393 // suddenly wrap to positive val, skipping entire negative range, add
394 // [INT_MIN + 1, smax()] range to the result to handle this.
395 auto newLhs = ConstantIntRanges::fromSigned(smin: lhs.smin() + 1, smax: lhs.smax());
396 result = result.rangeUnion(other: inferDivSRange(lhs: newLhs, rhs, fixup: ceilDivSIFix));
397 }
398 return result;
399}
400
401ConstantIntRanges
402mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
403 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
404
405 auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
406 const APInt &result) -> std::optional<APInt> {
407 if (!lhs.srem(RHS: rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
408 bool overflowed = false;
409 APInt corrected =
410 result.ssub_ov(RHS: APInt(result.getBitWidth(), 1), Overflow&: overflowed);
411 return overflowed ? std::optional<APInt>() : corrected;
412 }
413 return result;
414 };
415 return inferDivSRange(lhs, rhs, fixup: floorDivSIFix);
416}
417
418//===----------------------------------------------------------------------===//
419// Signed remainder (RemS)
420//===----------------------------------------------------------------------===//
421
422ConstantIntRanges
423mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
424 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
425 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
426 &rhsMax = rhs.smax();
427
428 unsigned width = rhsMax.getBitWidth();
429 APInt smin = APInt::getSignedMinValue(numBits: width);
430 APInt smax = APInt::getSignedMaxValue(numBits: width);
431 // No bounds if zero could be a divisor.
432 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
433 if (canBound) {
434 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
435 bool canNegativeDividend = lhsMin.isNegative();
436 bool canPositiveDividend = lhsMax.isStrictlyPositive();
437 APInt zero = APInt::getZero(numBits: maxDivisor.getBitWidth());
438 APInt maxPositiveResult = maxDivisor - 1;
439 APInt minNegativeResult = -maxPositiveResult;
440 smin = canNegativeDividend ? minNegativeResult : zero;
441 smax = canPositiveDividend ? maxPositiveResult : zero;
442 // Special case: sweeping out a contiguous range in N/[modulus].
443 if (rhsMin == rhsMax) {
444 if ((lhsMax - lhsMin).ult(RHS: maxDivisor)) {
445 APInt minRem = lhsMin.srem(RHS: maxDivisor);
446 APInt maxRem = lhsMax.srem(RHS: maxDivisor);
447 if (minRem.sle(RHS: maxRem)) {
448 smin = minRem;
449 smax = maxRem;
450 }
451 }
452 }
453 }
454 return ConstantIntRanges::fromSigned(smin, smax);
455}
456
457//===----------------------------------------------------------------------===//
458// Unsigned remainder (RemU)
459//===----------------------------------------------------------------------===//
460
461ConstantIntRanges
462mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
463 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
464 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
465
466 unsigned width = rhsMin.getBitWidth();
467 APInt umin = APInt::getZero(numBits: width);
468 // Remainder can't be larger than either of its arguments.
469 APInt umax = llvm::APIntOps::umin(A: (rhsMax - 1), B: lhs.umax());
470
471 if (!rhsMin.isZero()) {
472 // Special case: sweeping out a contiguous range in N/[modulus]
473 if (rhsMin == rhsMax) {
474 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
475 if ((lhsMax - lhsMin).ult(RHS: rhsMax)) {
476 APInt minRem = lhsMin.urem(RHS: rhsMax);
477 APInt maxRem = lhsMax.urem(RHS: rhsMax);
478 if (minRem.ule(RHS: maxRem)) {
479 umin = minRem;
480 umax = maxRem;
481 }
482 }
483 }
484 }
485 return ConstantIntRanges::fromUnsigned(umin, umax);
486}
487
488//===----------------------------------------------------------------------===//
489// Max and min (MaxS, MaxU, MinS, MinU)
490//===----------------------------------------------------------------------===//
491
492ConstantIntRanges
493mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
494 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
495
496 const APInt &smin = lhs.smin().sgt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin();
497 const APInt &smax = lhs.smax().sgt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax();
498 return ConstantIntRanges::fromSigned(smin, smax);
499}
500
501ConstantIntRanges
502mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
503 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
504
505 const APInt &umin = lhs.umin().ugt(RHS: rhs.umin()) ? lhs.umin() : rhs.umin();
506 const APInt &umax = lhs.umax().ugt(RHS: rhs.umax()) ? lhs.umax() : rhs.umax();
507 return ConstantIntRanges::fromUnsigned(umin, umax);
508}
509
510ConstantIntRanges
511mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
512 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
513
514 const APInt &smin = lhs.smin().slt(RHS: rhs.smin()) ? lhs.smin() : rhs.smin();
515 const APInt &smax = lhs.smax().slt(RHS: rhs.smax()) ? lhs.smax() : rhs.smax();
516 return ConstantIntRanges::fromSigned(smin, smax);
517}
518
519ConstantIntRanges
520mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
521 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
522
523 const APInt &umin = lhs.umin().ult(RHS: rhs.umin()) ? lhs.umin() : rhs.umin();
524 const APInt &umax = lhs.umax().ult(RHS: rhs.umax()) ? lhs.umax() : rhs.umax();
525 return ConstantIntRanges::fromUnsigned(umin, umax);
526}
527
528//===----------------------------------------------------------------------===//
529// Bitwise operators (And, Or, Xor)
530//===----------------------------------------------------------------------===//
531
532/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
533/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
534/// that both bonuds have in common. This gives us a consertive approximation
535/// for what values can be passed to bitwise operations.
536static std::tuple<APInt, APInt>
537widenBitwiseBounds(const ConstantIntRanges &bound) {
538 APInt leftVal = bound.umin(), rightVal = bound.umax();
539 unsigned bitwidth = leftVal.getBitWidth();
540 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
541 leftVal.clearLowBits(loBits: differingBits);
542 rightVal.setLowBits(differingBits);
543 return std::make_tuple(args: std::move(leftVal), args: std::move(rightVal));
544}
545
546ConstantIntRanges
547mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
548 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]);
549 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]);
550 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
551 return a & b;
552 };
553 return minMaxBy(op: andi, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes},
554 /*isSigned=*/false);
555}
556
557ConstantIntRanges
558mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
559 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(bound: argRanges[0]);
560 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(bound: argRanges[1]);
561 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
562 return a | b;
563 };
564 return minMaxBy(op: ori, lhs: {lhsZeros, lhsOnes}, rhs: {rhsZeros, rhsOnes},
565 /*isSigned=*/false);
566}
567
568/// Get bitmask of all bits which can change while iterating in
569/// [bound.umin(), bound.umax()].
570static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
571 APInt leftVal = bound.umin(), rightVal = bound.umax();
572 unsigned bitwidth = leftVal.getBitWidth();
573 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
574 return APInt::getLowBitsSet(numBits: bitwidth, loBitsSet: differingBits);
575}
576
577ConstantIntRanges
578mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
579 // Construct mask of varying bits for both ranges, xor values and then replace
580 // masked bits with 0s and 1s to get min and max values respectively.
581 ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
582 APInt mask = getVaryingBitsMask(bound: lhs) | getVaryingBitsMask(bound: rhs);
583 APInt res = lhs.umin() ^ rhs.umin();
584 APInt min = res & ~mask;
585 APInt max = res | mask;
586 return ConstantIntRanges::fromUnsigned(umin: min, umax: max);
587}
588
589//===----------------------------------------------------------------------===//
590// Shifts (Shl, ShrS, ShrU)
591//===----------------------------------------------------------------------===//
592
593ConstantIntRanges
594mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
595 OverflowFlags ovfFlags) {
596 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
597 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
598
599 // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
600 // 2^rhs.
601 ConstArithStdFn ushl = [=](const APInt &l,
602 const APInt &r) -> std::optional<APInt> {
603 bool overflowed = false;
604 APInt result = any(Val: ovfFlags & OverflowFlags::Nuw)
605 ? l.ushl_sat(RHS: r)
606 : l.ushl_ov(Amt: r, Overflow&: overflowed);
607 return overflowed ? std::optional<APInt>() : result;
608 };
609 ConstArithStdFn sshl = [=](const APInt &l,
610 const APInt &r) -> std::optional<APInt> {
611 bool overflowed = false;
612 APInt result = any(Val: ovfFlags & OverflowFlags::Nsw)
613 ? l.sshl_sat(RHS: r)
614 : l.sshl_ov(Amt: r, Overflow&: overflowed);
615 return overflowed ? std::optional<APInt>() : result;
616 };
617
618 ConstantIntRanges urange =
619 minMaxBy(op: ushl, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhsUMin, rhsUMax},
620 /*isSigned=*/false);
621 ConstantIntRanges srange =
622 minMaxBy(op: sshl, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhsUMin, rhsUMax},
623 /*isSigned=*/true);
624 return urange.intersection(other: srange);
625}
626
627ConstantIntRanges
628mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
629 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
630
631 auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
632 return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.ashr(ShiftAmt: r);
633 };
634
635 return minMaxBy(op: ashr, lhs: {lhs.smin(), lhs.smax()}, rhs: {rhs.umin(), rhs.umax()},
636 /*isSigned=*/true);
637}
638
639ConstantIntRanges
640mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
641 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
642
643 auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
644 return r.uge(RHS: r.getBitWidth()) ? std::optional<APInt>() : l.lshr(ShiftAmt: r);
645 };
646 return minMaxBy(op: lshr, lhs: {lhs.umin(), lhs.umax()}, rhs: {rhs.umin(), rhs.umax()},
647 /*isSigned=*/false);
648}
649
650//===----------------------------------------------------------------------===//
651// Comparisons (Cmp)
652//===----------------------------------------------------------------------===//
653
654static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
655 switch (pred) {
656 case intrange::CmpPredicate::eq:
657 return intrange::CmpPredicate::ne;
658 case intrange::CmpPredicate::ne:
659 return intrange::CmpPredicate::eq;
660 case intrange::CmpPredicate::slt:
661 return intrange::CmpPredicate::sge;
662 case intrange::CmpPredicate::sle:
663 return intrange::CmpPredicate::sgt;
664 case intrange::CmpPredicate::sgt:
665 return intrange::CmpPredicate::sle;
666 case intrange::CmpPredicate::sge:
667 return intrange::CmpPredicate::slt;
668 case intrange::CmpPredicate::ult:
669 return intrange::CmpPredicate::uge;
670 case intrange::CmpPredicate::ule:
671 return intrange::CmpPredicate::ugt;
672 case intrange::CmpPredicate::ugt:
673 return intrange::CmpPredicate::ule;
674 case intrange::CmpPredicate::uge:
675 return intrange::CmpPredicate::ult;
676 }
677 llvm_unreachable("unknown cmp predicate value");
678}
679
680static bool isStaticallyTrue(intrange::CmpPredicate pred,
681 const ConstantIntRanges &lhs,
682 const ConstantIntRanges &rhs) {
683 switch (pred) {
684 case intrange::CmpPredicate::sle:
685 return lhs.smax().sle(RHS: rhs.smin());
686 case intrange::CmpPredicate::slt:
687 return lhs.smax().slt(RHS: rhs.smin());
688 case intrange::CmpPredicate::ule:
689 return lhs.umax().ule(RHS: rhs.umin());
690 case intrange::CmpPredicate::ult:
691 return lhs.umax().ult(RHS: rhs.umin());
692 case intrange::CmpPredicate::sge:
693 return lhs.smin().sge(RHS: rhs.smax());
694 case intrange::CmpPredicate::sgt:
695 return lhs.smin().sgt(RHS: rhs.smax());
696 case intrange::CmpPredicate::uge:
697 return lhs.umin().uge(RHS: rhs.umax());
698 case intrange::CmpPredicate::ugt:
699 return lhs.umin().ugt(RHS: rhs.umax());
700 case intrange::CmpPredicate::eq: {
701 std::optional<APInt> lhsConst = lhs.getConstantValue();
702 std::optional<APInt> rhsConst = rhs.getConstantValue();
703 return lhsConst && rhsConst && lhsConst == rhsConst;
704 }
705 case intrange::CmpPredicate::ne: {
706 // While equality requires that there is an interpration of the preceeding
707 // computations that produces equal constants, whether that be signed or
708 // unsigned, statically determining inequality requires that neither
709 // interpretation produce potentially overlapping ranges.
710 bool sne = isStaticallyTrue(pred: intrange::CmpPredicate::slt, lhs, rhs) ||
711 isStaticallyTrue(pred: intrange::CmpPredicate::sgt, lhs, rhs);
712 bool une = isStaticallyTrue(pred: intrange::CmpPredicate::ult, lhs, rhs) ||
713 isStaticallyTrue(pred: intrange::CmpPredicate::ugt, lhs, rhs);
714 return sne && une;
715 }
716 }
717 return false;
718}
719
720std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
721 const ConstantIntRanges &lhs,
722 const ConstantIntRanges &rhs) {
723 if (isStaticallyTrue(pred, lhs, rhs))
724 return true;
725 if (isStaticallyTrue(pred: invertPredicate(pred), lhs, rhs))
726 return false;
727 return std::nullopt;
728}
729
730//===----------------------------------------------------------------------===//
731// Shaped type dimension accessors / ShapedDimOpInterface
732//===----------------------------------------------------------------------===//
733
734ConstantIntRanges
735mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
736 const IntegerValueRange &maybeDim) {
737 unsigned width =
738 ConstantIntRanges::getStorageBitwidth(type: op->getResult(0).getType());
739 APInt zero = APInt::getZero(numBits: width);
740 APInt typeMax = APInt::getSignedMaxValue(numBits: width);
741
742 auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
743 if (!shapedTy.hasRank())
744 return ConstantIntRanges::fromSigned(smin: zero, smax: typeMax);
745
746 int64_t rank = shapedTy.getRank();
747 int64_t minDim = 0;
748 int64_t maxDim = rank - 1;
749 if (!maybeDim.isUninitialized()) {
750 const ConstantIntRanges &dim = maybeDim.getValue();
751 minDim = std::max(a: minDim, b: dim.smin().getSExtValue());
752 maxDim = std::min(a: maxDim, b: dim.smax().getSExtValue());
753 }
754
755 std::optional<ConstantIntRanges> result;
756 auto joinResult = [&](const ConstantIntRanges &thisResult) {
757 if (!result.has_value())
758 result = thisResult;
759 else
760 result = result->rangeUnion(other: thisResult);
761 };
762 for (int64_t i = minDim; i <= maxDim; ++i) {
763 int64_t length = shapedTy.getDimSize(i);
764
765 if (ShapedType::isDynamic(length))
766 joinResult(ConstantIntRanges::fromSigned(smin: zero, smax: typeMax));
767 else
768 joinResult(ConstantIntRanges::constant(value: APInt(width, length)));
769 }
770 return result.value_or(u: ConstantIntRanges::fromSigned(smin: zero, smax: typeMax));
771}
772

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp