1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Interfaces/InferIntRangeInterface.h"
11#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13#include "llvm/Support/Debug.h"
14#include <optional>
15
16#define DEBUG_TYPE "int-range-analysis"
17
18using namespace mlir;
19using namespace mlir::arith;
20using namespace mlir::intrange;
21
22//===----------------------------------------------------------------------===//
23// ConstantOp
24//===----------------------------------------------------------------------===//
25
26void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27 SetIntRangeFn setResultRange) {
28 auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
29 if (constAttr) {
30 const APInt &value = constAttr.getValue();
31 setResultRange(getResult(), ConstantIntRanges::constant(value));
32 }
33}
34
35//===----------------------------------------------------------------------===//
36// AddIOp
37//===----------------------------------------------------------------------===//
38
39void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
40 SetIntRangeFn setResultRange) {
41 setResultRange(getResult(), inferAdd(argRanges));
42}
43
44//===----------------------------------------------------------------------===//
45// SubIOp
46//===----------------------------------------------------------------------===//
47
48void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
49 SetIntRangeFn setResultRange) {
50 setResultRange(getResult(), inferSub(argRanges));
51}
52
53//===----------------------------------------------------------------------===//
54// MulIOp
55//===----------------------------------------------------------------------===//
56
57void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58 SetIntRangeFn setResultRange) {
59 setResultRange(getResult(), inferMul(argRanges));
60}
61
62//===----------------------------------------------------------------------===//
63// DivUIOp
64//===----------------------------------------------------------------------===//
65
66void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
67 SetIntRangeFn setResultRange) {
68 setResultRange(getResult(), inferDivU(argRanges));
69}
70
71//===----------------------------------------------------------------------===//
72// DivSIOp
73//===----------------------------------------------------------------------===//
74
75void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76 SetIntRangeFn setResultRange) {
77 setResultRange(getResult(), inferDivS(argRanges));
78}
79
80//===----------------------------------------------------------------------===//
81// CeilDivUIOp
82//===----------------------------------------------------------------------===//
83
84void arith::CeilDivUIOp::inferResultRanges(
85 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
86 setResultRange(getResult(), inferCeilDivU(argRanges));
87}
88
89//===----------------------------------------------------------------------===//
90// CeilDivSIOp
91//===----------------------------------------------------------------------===//
92
93void arith::CeilDivSIOp::inferResultRanges(
94 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
95 setResultRange(getResult(), inferCeilDivS(argRanges));
96}
97
98//===----------------------------------------------------------------------===//
99// FloorDivSIOp
100//===----------------------------------------------------------------------===//
101
102void arith::FloorDivSIOp::inferResultRanges(
103 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
104 return setResultRange(getResult(), inferFloorDivS(argRanges));
105}
106
107//===----------------------------------------------------------------------===//
108// RemUIOp
109//===----------------------------------------------------------------------===//
110
111void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112 SetIntRangeFn setResultRange) {
113 setResultRange(getResult(), inferRemU(argRanges));
114}
115
116//===----------------------------------------------------------------------===//
117// RemSIOp
118//===----------------------------------------------------------------------===//
119
120void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
121 SetIntRangeFn setResultRange) {
122 setResultRange(getResult(), inferRemS(argRanges));
123}
124
125//===----------------------------------------------------------------------===//
126// AndIOp
127//===----------------------------------------------------------------------===//
128
129void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130 SetIntRangeFn setResultRange) {
131 setResultRange(getResult(), inferAnd(argRanges));
132}
133
134//===----------------------------------------------------------------------===//
135// OrIOp
136//===----------------------------------------------------------------------===//
137
138void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
139 SetIntRangeFn setResultRange) {
140 setResultRange(getResult(), inferOr(argRanges));
141}
142
143//===----------------------------------------------------------------------===//
144// XOrIOp
145//===----------------------------------------------------------------------===//
146
147void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148 SetIntRangeFn setResultRange) {
149 setResultRange(getResult(), inferXor(argRanges));
150}
151
152//===----------------------------------------------------------------------===//
153// MaxSIOp
154//===----------------------------------------------------------------------===//
155
156void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
157 SetIntRangeFn setResultRange) {
158 setResultRange(getResult(), inferMaxS(argRanges));
159}
160
161//===----------------------------------------------------------------------===//
162// MaxUIOp
163//===----------------------------------------------------------------------===//
164
165void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166 SetIntRangeFn setResultRange) {
167 setResultRange(getResult(), inferMaxU(argRanges));
168}
169
170//===----------------------------------------------------------------------===//
171// MinSIOp
172//===----------------------------------------------------------------------===//
173
174void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
175 SetIntRangeFn setResultRange) {
176 setResultRange(getResult(), inferMinS(argRanges));
177}
178
179//===----------------------------------------------------------------------===//
180// MinUIOp
181//===----------------------------------------------------------------------===//
182
183void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
184 SetIntRangeFn setResultRange) {
185 setResultRange(getResult(), inferMinU(argRanges));
186}
187
188//===----------------------------------------------------------------------===//
189// ExtUIOp
190//===----------------------------------------------------------------------===//
191
192void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
193 SetIntRangeFn setResultRange) {
194 unsigned destWidth =
195 ConstantIntRanges::getStorageBitwidth(getResult().getType());
196 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
197}
198
199//===----------------------------------------------------------------------===//
200// ExtSIOp
201//===----------------------------------------------------------------------===//
202
203void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
204 SetIntRangeFn setResultRange) {
205 unsigned destWidth =
206 ConstantIntRanges::getStorageBitwidth(getResult().getType());
207 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
208}
209
210//===----------------------------------------------------------------------===//
211// TruncIOp
212//===----------------------------------------------------------------------===//
213
214void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
215 SetIntRangeFn setResultRange) {
216 unsigned destWidth =
217 ConstantIntRanges::getStorageBitwidth(getResult().getType());
218 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
219}
220
221//===----------------------------------------------------------------------===//
222// IndexCastOp
223//===----------------------------------------------------------------------===//
224
225void arith::IndexCastOp::inferResultRanges(
226 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
227 Type sourceType = getOperand().getType();
228 Type destType = getResult().getType();
229 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
230 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
231
232 if (srcWidth < destWidth)
233 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
234 else if (srcWidth > destWidth)
235 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
236 else
237 setResultRange(getResult(), argRanges[0]);
238}
239
240//===----------------------------------------------------------------------===//
241// IndexCastUIOp
242//===----------------------------------------------------------------------===//
243
244void arith::IndexCastUIOp::inferResultRanges(
245 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
246 Type sourceType = getOperand().getType();
247 Type destType = getResult().getType();
248 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
249 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
250
251 if (srcWidth < destWidth)
252 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
253 else if (srcWidth > destWidth)
254 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
255 else
256 setResultRange(getResult(), argRanges[0]);
257}
258
259//===----------------------------------------------------------------------===//
260// CmpIOp
261//===----------------------------------------------------------------------===//
262
263void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
264 SetIntRangeFn setResultRange) {
265 arith::CmpIPredicate arithPred = getPredicate();
266 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
267 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
268
269 APInt min = APInt::getZero(1);
270 APInt max = APInt::getAllOnes(1);
271
272 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
273 if (truthValue.has_value() && *truthValue)
274 min = max;
275 else if (truthValue.has_value() && !(*truthValue))
276 max = min;
277
278 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
279}
280
281//===----------------------------------------------------------------------===//
282// SelectOp
283//===----------------------------------------------------------------------===//
284
285void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
286 SetIntRangeFn setResultRange) {
287 std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
288
289 if (mbCondVal) {
290 if (mbCondVal->isZero())
291 setResultRange(getResult(), argRanges[2]);
292 else
293 setResultRange(getResult(), argRanges[1]);
294 return;
295 }
296 setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
297}
298
299//===----------------------------------------------------------------------===//
300// ShLIOp
301//===----------------------------------------------------------------------===//
302
303void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
304 SetIntRangeFn setResultRange) {
305 setResultRange(getResult(), inferShl(argRanges));
306}
307
308//===----------------------------------------------------------------------===//
309// ShRUIOp
310//===----------------------------------------------------------------------===//
311
312void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
313 SetIntRangeFn setResultRange) {
314 setResultRange(getResult(), inferShrU(argRanges));
315}
316
317//===----------------------------------------------------------------------===//
318// ShRSIOp
319//===----------------------------------------------------------------------===//
320
321void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
322 SetIntRangeFn setResultRange) {
323 setResultRange(getResult(), inferShrS(argRanges));
324}
325

source code of mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp