1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Interfaces/InferIntRangeInterface.h"
11#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
12
13#include "llvm/Support/Debug.h"
14#include <optional>
15
16#define DEBUG_TYPE "int-range-analysis"
17
18using namespace mlir;
19using namespace mlir::arith;
20using namespace mlir::intrange;
21
22static intrange::OverflowFlags
23convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
24 intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
25 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
26 retFlags |= intrange::OverflowFlags::Nsw;
27 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28 retFlags |= intrange::OverflowFlags::Nuw;
29 return retFlags;
30}
31
32//===----------------------------------------------------------------------===//
33// ConstantOp
34//===----------------------------------------------------------------------===//
35
36void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
37 SetIntRangeFn setResultRange) {
38 if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
39 const APInt &value = scalarCstAttr.getValue();
40 setResultRange(getResult(), ConstantIntRanges::constant(value));
41 return;
42 }
43 if (auto arrayCstAttr =
44 llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45 if (arrayCstAttr.isSplat()) {
46 setResultRange(getResult(), ConstantIntRanges::constant(
47 arrayCstAttr.getSplatValue<APInt>()));
48 return;
49 }
50
51 std::optional<ConstantIntRanges> result;
52 for (const APInt &val : arrayCstAttr) {
53 auto range = ConstantIntRanges::constant(val);
54 result = (result ? result->rangeUnion(range) : range);
55 }
56
57 assert(result && "Zero-sized vectors are not allowed");
58 setResultRange(getResult(), *result);
59 return;
60 }
61}
62
63//===----------------------------------------------------------------------===//
64// AddIOp
65//===----------------------------------------------------------------------===//
66
67void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
68 SetIntRangeFn setResultRange) {
69 setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
70 getOverflowFlags())));
71}
72
73//===----------------------------------------------------------------------===//
74// SubIOp
75//===----------------------------------------------------------------------===//
76
77void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
78 SetIntRangeFn setResultRange) {
79 setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
80 getOverflowFlags())));
81}
82
83//===----------------------------------------------------------------------===//
84// MulIOp
85//===----------------------------------------------------------------------===//
86
87void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88 SetIntRangeFn setResultRange) {
89 setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
90 getOverflowFlags())));
91}
92
93//===----------------------------------------------------------------------===//
94// DivUIOp
95//===----------------------------------------------------------------------===//
96
97void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
98 SetIntRangeFn setResultRange) {
99 setResultRange(getResult(), inferDivU(argRanges));
100}
101
102//===----------------------------------------------------------------------===//
103// DivSIOp
104//===----------------------------------------------------------------------===//
105
106void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
107 SetIntRangeFn setResultRange) {
108 setResultRange(getResult(), inferDivS(argRanges));
109}
110
111//===----------------------------------------------------------------------===//
112// CeilDivUIOp
113//===----------------------------------------------------------------------===//
114
115void arith::CeilDivUIOp::inferResultRanges(
116 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
117 setResultRange(getResult(), inferCeilDivU(argRanges));
118}
119
120//===----------------------------------------------------------------------===//
121// CeilDivSIOp
122//===----------------------------------------------------------------------===//
123
124void arith::CeilDivSIOp::inferResultRanges(
125 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
126 setResultRange(getResult(), inferCeilDivS(argRanges));
127}
128
129//===----------------------------------------------------------------------===//
130// FloorDivSIOp
131//===----------------------------------------------------------------------===//
132
133void arith::FloorDivSIOp::inferResultRanges(
134 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
135 return setResultRange(getResult(), inferFloorDivS(argRanges));
136}
137
138//===----------------------------------------------------------------------===//
139// RemUIOp
140//===----------------------------------------------------------------------===//
141
142void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
143 SetIntRangeFn setResultRange) {
144 setResultRange(getResult(), inferRemU(argRanges));
145}
146
147//===----------------------------------------------------------------------===//
148// RemSIOp
149//===----------------------------------------------------------------------===//
150
151void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152 SetIntRangeFn setResultRange) {
153 setResultRange(getResult(), inferRemS(argRanges));
154}
155
156//===----------------------------------------------------------------------===//
157// AndIOp
158//===----------------------------------------------------------------------===//
159
160void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
161 SetIntRangeFn setResultRange) {
162 setResultRange(getResult(), inferAnd(argRanges));
163}
164
165//===----------------------------------------------------------------------===//
166// OrIOp
167//===----------------------------------------------------------------------===//
168
169void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
170 SetIntRangeFn setResultRange) {
171 setResultRange(getResult(), inferOr(argRanges));
172}
173
174//===----------------------------------------------------------------------===//
175// XOrIOp
176//===----------------------------------------------------------------------===//
177
178void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
179 SetIntRangeFn setResultRange) {
180 setResultRange(getResult(), inferXor(argRanges));
181}
182
183//===----------------------------------------------------------------------===//
184// MaxSIOp
185//===----------------------------------------------------------------------===//
186
187void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
188 SetIntRangeFn setResultRange) {
189 setResultRange(getResult(), inferMaxS(argRanges));
190}
191
192//===----------------------------------------------------------------------===//
193// MaxUIOp
194//===----------------------------------------------------------------------===//
195
196void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
197 SetIntRangeFn setResultRange) {
198 setResultRange(getResult(), inferMaxU(argRanges));
199}
200
201//===----------------------------------------------------------------------===//
202// MinSIOp
203//===----------------------------------------------------------------------===//
204
205void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
206 SetIntRangeFn setResultRange) {
207 setResultRange(getResult(), inferMinS(argRanges));
208}
209
210//===----------------------------------------------------------------------===//
211// MinUIOp
212//===----------------------------------------------------------------------===//
213
214void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
215 SetIntRangeFn setResultRange) {
216 setResultRange(getResult(), inferMinU(argRanges));
217}
218
219//===----------------------------------------------------------------------===//
220// ExtUIOp
221//===----------------------------------------------------------------------===//
222
223void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
224 SetIntRangeFn setResultRange) {
225 unsigned destWidth =
226 ConstantIntRanges::getStorageBitwidth(getResult().getType());
227 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
228}
229
230//===----------------------------------------------------------------------===//
231// ExtSIOp
232//===----------------------------------------------------------------------===//
233
234void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
235 SetIntRangeFn setResultRange) {
236 unsigned destWidth =
237 ConstantIntRanges::getStorageBitwidth(getResult().getType());
238 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
239}
240
241//===----------------------------------------------------------------------===//
242// TruncIOp
243//===----------------------------------------------------------------------===//
244
245void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
246 SetIntRangeFn setResultRange) {
247 unsigned destWidth =
248 ConstantIntRanges::getStorageBitwidth(getResult().getType());
249 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
250}
251
252//===----------------------------------------------------------------------===//
253// IndexCastOp
254//===----------------------------------------------------------------------===//
255
256void arith::IndexCastOp::inferResultRanges(
257 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
258 Type sourceType = getOperand().getType();
259 Type destType = getResult().getType();
260 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
261 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
262
263 if (srcWidth < destWidth)
264 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
265 else if (srcWidth > destWidth)
266 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
267 else
268 setResultRange(getResult(), argRanges[0]);
269}
270
271//===----------------------------------------------------------------------===//
272// IndexCastUIOp
273//===----------------------------------------------------------------------===//
274
275void arith::IndexCastUIOp::inferResultRanges(
276 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
277 Type sourceType = getOperand().getType();
278 Type destType = getResult().getType();
279 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
280 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
281
282 if (srcWidth < destWidth)
283 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
284 else if (srcWidth > destWidth)
285 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
286 else
287 setResultRange(getResult(), argRanges[0]);
288}
289
290//===----------------------------------------------------------------------===//
291// CmpIOp
292//===----------------------------------------------------------------------===//
293
294void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
295 SetIntRangeFn setResultRange) {
296 arith::CmpIPredicate arithPred = getPredicate();
297 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
298 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
299
300 APInt min = APInt::getZero(1);
301 APInt max = APInt::getAllOnes(1);
302
303 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
304 if (truthValue.has_value() && *truthValue)
305 min = max;
306 else if (truthValue.has_value() && !(*truthValue))
307 max = min;
308
309 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
310}
311
312//===----------------------------------------------------------------------===//
313// SelectOp
314//===----------------------------------------------------------------------===//
315
316void arith::SelectOp::inferResultRangesFromOptional(
317 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
318 std::optional<APInt> mbCondVal =
319 argRanges[0].isUninitialized()
320 ? std::nullopt
321 : argRanges[0].getValue().getConstantValue();
322
323 const IntegerValueRange &trueCase = argRanges[1];
324 const IntegerValueRange &falseCase = argRanges[2];
325
326 if (mbCondVal) {
327 if (mbCondVal->isZero())
328 setResultRange(getResult(), falseCase);
329 else
330 setResultRange(getResult(), trueCase);
331 return;
332 }
333 setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
334}
335
336//===----------------------------------------------------------------------===//
337// ShLIOp
338//===----------------------------------------------------------------------===//
339
340void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
341 SetIntRangeFn setResultRange) {
342 setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
343 getOverflowFlags())));
344}
345
346//===----------------------------------------------------------------------===//
347// ShRUIOp
348//===----------------------------------------------------------------------===//
349
350void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
351 SetIntRangeFn setResultRange) {
352 setResultRange(getResult(), inferShrU(argRanges));
353}
354
355//===----------------------------------------------------------------------===//
356// ShRSIOp
357//===----------------------------------------------------------------------===//
358
359void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
360 SetIntRangeFn setResultRange) {
361 setResultRange(getResult(), inferShrS(argRanges));
362}
363

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp