1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Interfaces/InferIntRangeInterface.h"
11#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13#include <optional>
14
15#define DEBUG_TYPE "int-range-analysis"
16
17using namespace mlir;
18using namespace mlir::arith;
19using namespace mlir::intrange;
20
21static intrange::OverflowFlags
22convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
23 intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
24 if (bitEnumContainsAny(bits: flags, bit: arith::IntegerOverflowFlags::nsw))
25 retFlags |= intrange::OverflowFlags::Nsw;
26 if (bitEnumContainsAny(bits: flags, bit: arith::IntegerOverflowFlags::nuw))
27 retFlags |= intrange::OverflowFlags::Nuw;
28 return retFlags;
29}
30
31//===----------------------------------------------------------------------===//
32// ConstantOp
33//===----------------------------------------------------------------------===//
34
35void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
36 SetIntRangeFn setResultRange) {
37 if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(Val: getValue())) {
38 const APInt &value = scalarCstAttr.getValue();
39 setResultRange(getResult(), ConstantIntRanges::constant(value));
40 return;
41 }
42 if (auto arrayCstAttr =
43 llvm::dyn_cast_or_null<DenseIntElementsAttr>(Val: getValue())) {
44 if (arrayCstAttr.isSplat()) {
45 setResultRange(getResult(), ConstantIntRanges::constant(
46 value: arrayCstAttr.getSplatValue<APInt>()));
47 return;
48 }
49
50 std::optional<ConstantIntRanges> result;
51 for (const APInt &val : arrayCstAttr) {
52 auto range = ConstantIntRanges::constant(value: val);
53 result = (result ? result->rangeUnion(other: range) : range);
54 }
55
56 assert(result && "Zero-sized vectors are not allowed");
57 setResultRange(getResult(), *result);
58 return;
59 }
60}
61
62//===----------------------------------------------------------------------===//
63// AddIOp
64//===----------------------------------------------------------------------===//
65
66void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
67 SetIntRangeFn setResultRange) {
68 setResultRange(getResult(), inferAdd(argRanges, ovfFlags: convertArithOverflowFlags(
69 flags: getOverflowFlags())));
70}
71
72//===----------------------------------------------------------------------===//
73// SubIOp
74//===----------------------------------------------------------------------===//
75
76void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
77 SetIntRangeFn setResultRange) {
78 setResultRange(getResult(), inferSub(argRanges, ovfFlags: convertArithOverflowFlags(
79 flags: getOverflowFlags())));
80}
81
82//===----------------------------------------------------------------------===//
83// MulIOp
84//===----------------------------------------------------------------------===//
85
86void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87 SetIntRangeFn setResultRange) {
88 setResultRange(getResult(), inferMul(argRanges, ovfFlags: convertArithOverflowFlags(
89 flags: getOverflowFlags())));
90}
91
92//===----------------------------------------------------------------------===//
93// DivUIOp
94//===----------------------------------------------------------------------===//
95
96void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
97 SetIntRangeFn setResultRange) {
98 setResultRange(getResult(), inferDivU(argRanges));
99}
100
101//===----------------------------------------------------------------------===//
102// DivSIOp
103//===----------------------------------------------------------------------===//
104
105void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106 SetIntRangeFn setResultRange) {
107 setResultRange(getResult(), inferDivS(argRanges));
108}
109
110//===----------------------------------------------------------------------===//
111// CeilDivUIOp
112//===----------------------------------------------------------------------===//
113
114void arith::CeilDivUIOp::inferResultRanges(
115 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
116 setResultRange(getResult(), inferCeilDivU(argRanges));
117}
118
119//===----------------------------------------------------------------------===//
120// CeilDivSIOp
121//===----------------------------------------------------------------------===//
122
123void arith::CeilDivSIOp::inferResultRanges(
124 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
125 setResultRange(getResult(), inferCeilDivS(argRanges));
126}
127
128//===----------------------------------------------------------------------===//
129// FloorDivSIOp
130//===----------------------------------------------------------------------===//
131
132void arith::FloorDivSIOp::inferResultRanges(
133 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
134 return setResultRange(getResult(), inferFloorDivS(argRanges));
135}
136
137//===----------------------------------------------------------------------===//
138// RemUIOp
139//===----------------------------------------------------------------------===//
140
141void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142 SetIntRangeFn setResultRange) {
143 setResultRange(getResult(), inferRemU(argRanges));
144}
145
146//===----------------------------------------------------------------------===//
147// RemSIOp
148//===----------------------------------------------------------------------===//
149
150void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
151 SetIntRangeFn setResultRange) {
152 setResultRange(getResult(), inferRemS(argRanges));
153}
154
155//===----------------------------------------------------------------------===//
156// AndIOp
157//===----------------------------------------------------------------------===//
158
159void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160 SetIntRangeFn setResultRange) {
161 setResultRange(getResult(), inferAnd(argRanges));
162}
163
164//===----------------------------------------------------------------------===//
165// OrIOp
166//===----------------------------------------------------------------------===//
167
168void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
169 SetIntRangeFn setResultRange) {
170 setResultRange(getResult(), inferOr(argRanges));
171}
172
173//===----------------------------------------------------------------------===//
174// XOrIOp
175//===----------------------------------------------------------------------===//
176
177void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
178 SetIntRangeFn setResultRange) {
179 setResultRange(getResult(), inferXor(argRanges));
180}
181
182//===----------------------------------------------------------------------===//
183// MaxSIOp
184//===----------------------------------------------------------------------===//
185
186void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
187 SetIntRangeFn setResultRange) {
188 setResultRange(getResult(), inferMaxS(argRanges));
189}
190
191//===----------------------------------------------------------------------===//
192// MaxUIOp
193//===----------------------------------------------------------------------===//
194
195void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
196 SetIntRangeFn setResultRange) {
197 setResultRange(getResult(), inferMaxU(argRanges));
198}
199
200//===----------------------------------------------------------------------===//
201// MinSIOp
202//===----------------------------------------------------------------------===//
203
204void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
205 SetIntRangeFn setResultRange) {
206 setResultRange(getResult(), inferMinS(argRanges));
207}
208
209//===----------------------------------------------------------------------===//
210// MinUIOp
211//===----------------------------------------------------------------------===//
212
213void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
214 SetIntRangeFn setResultRange) {
215 setResultRange(getResult(), inferMinU(argRanges));
216}
217
218//===----------------------------------------------------------------------===//
219// ExtUIOp
220//===----------------------------------------------------------------------===//
221
222void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
223 SetIntRangeFn setResultRange) {
224 unsigned destWidth =
225 ConstantIntRanges::getStorageBitwidth(type: getResult().getType());
226 setResultRange(getResult(), extUIRange(range: argRanges[0], destWidth));
227}
228
229//===----------------------------------------------------------------------===//
230// ExtSIOp
231//===----------------------------------------------------------------------===//
232
233void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
234 SetIntRangeFn setResultRange) {
235 unsigned destWidth =
236 ConstantIntRanges::getStorageBitwidth(type: getResult().getType());
237 setResultRange(getResult(), extSIRange(range: argRanges[0], destWidth));
238}
239
240//===----------------------------------------------------------------------===//
241// TruncIOp
242//===----------------------------------------------------------------------===//
243
244void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
245 SetIntRangeFn setResultRange) {
246 unsigned destWidth =
247 ConstantIntRanges::getStorageBitwidth(type: getResult().getType());
248 setResultRange(getResult(), truncRange(range: argRanges[0], destWidth));
249}
250
251//===----------------------------------------------------------------------===//
252// IndexCastOp
253//===----------------------------------------------------------------------===//
254
255void arith::IndexCastOp::inferResultRanges(
256 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
257 Type sourceType = getOperand().getType();
258 Type destType = getResult().getType();
259 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType);
260 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType);
261
262 if (srcWidth < destWidth)
263 setResultRange(getResult(), extSIRange(range: argRanges[0], destWidth));
264 else if (srcWidth > destWidth)
265 setResultRange(getResult(), truncRange(range: argRanges[0], destWidth));
266 else
267 setResultRange(getResult(), argRanges[0]);
268}
269
270//===----------------------------------------------------------------------===//
271// IndexCastUIOp
272//===----------------------------------------------------------------------===//
273
274void arith::IndexCastUIOp::inferResultRanges(
275 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
276 Type sourceType = getOperand().getType();
277 Type destType = getResult().getType();
278 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType);
279 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType);
280
281 if (srcWidth < destWidth)
282 setResultRange(getResult(), extUIRange(range: argRanges[0], destWidth));
283 else if (srcWidth > destWidth)
284 setResultRange(getResult(), truncRange(range: argRanges[0], destWidth));
285 else
286 setResultRange(getResult(), argRanges[0]);
287}
288
289//===----------------------------------------------------------------------===//
290// CmpIOp
291//===----------------------------------------------------------------------===//
292
293void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
294 SetIntRangeFn setResultRange) {
295 arith::CmpIPredicate arithPred = getPredicate();
296 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
297 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
298
299 APInt min = APInt::getZero(numBits: 1);
300 APInt max = APInt::getAllOnes(numBits: 1);
301
302 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
303 if (truthValue.has_value() && *truthValue)
304 min = max;
305 else if (truthValue.has_value() && !(*truthValue))
306 max = min;
307
308 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin: min, umax: max));
309}
310
311//===----------------------------------------------------------------------===//
312// SelectOp
313//===----------------------------------------------------------------------===//
314
315void arith::SelectOp::inferResultRangesFromOptional(
316 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
317 std::optional<APInt> mbCondVal =
318 argRanges[0].isUninitialized()
319 ? std::nullopt
320 : argRanges[0].getValue().getConstantValue();
321
322 const IntegerValueRange &trueCase = argRanges[1];
323 const IntegerValueRange &falseCase = argRanges[2];
324
325 if (mbCondVal) {
326 if (mbCondVal->isZero())
327 setResultRange(getResult(), falseCase);
328 else
329 setResultRange(getResult(), trueCase);
330 return;
331 }
332 setResultRange(getResult(), IntegerValueRange::join(lhs: trueCase, rhs: falseCase));
333}
334
335//===----------------------------------------------------------------------===//
336// ShLIOp
337//===----------------------------------------------------------------------===//
338
339void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
340 SetIntRangeFn setResultRange) {
341 setResultRange(getResult(), inferShl(argRanges, ovfFlags: convertArithOverflowFlags(
342 flags: getOverflowFlags())));
343}
344
345//===----------------------------------------------------------------------===//
346// ShRUIOp
347//===----------------------------------------------------------------------===//
348
349void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
350 SetIntRangeFn setResultRange) {
351 setResultRange(getResult(), inferShrU(argRanges));
352}
353
354//===----------------------------------------------------------------------===//
355// ShRSIOp
356//===----------------------------------------------------------------------===//
357
358void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
359 SetIntRangeFn setResultRange) {
360 setResultRange(getResult(), inferShrS(argRanges));
361}
362

source code of mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp