1 | //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
11 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
12 | |
13 | #include "llvm/Support/Debug.h" |
14 | #include <optional> |
15 | |
16 | #define DEBUG_TYPE "int-range-analysis" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::arith; |
20 | using namespace mlir::intrange; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // ConstantOp |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
27 | SetIntRangeFn setResultRange) { |
28 | auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue()); |
29 | if (constAttr) { |
30 | const APInt &value = constAttr.getValue(); |
31 | setResultRange(getResult(), ConstantIntRanges::constant(value)); |
32 | } |
33 | } |
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // AddIOp |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
40 | SetIntRangeFn setResultRange) { |
41 | setResultRange(getResult(), inferAdd(argRanges)); |
42 | } |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // SubIOp |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
49 | SetIntRangeFn setResultRange) { |
50 | setResultRange(getResult(), inferSub(argRanges)); |
51 | } |
52 | |
53 | //===----------------------------------------------------------------------===// |
54 | // MulIOp |
55 | //===----------------------------------------------------------------------===// |
56 | |
57 | void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
58 | SetIntRangeFn setResultRange) { |
59 | setResultRange(getResult(), inferMul(argRanges)); |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // DivUIOp |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
67 | SetIntRangeFn setResultRange) { |
68 | setResultRange(getResult(), inferDivU(argRanges)); |
69 | } |
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | // DivSIOp |
73 | //===----------------------------------------------------------------------===// |
74 | |
75 | void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
76 | SetIntRangeFn setResultRange) { |
77 | setResultRange(getResult(), inferDivS(argRanges)); |
78 | } |
79 | |
80 | //===----------------------------------------------------------------------===// |
81 | // CeilDivUIOp |
82 | //===----------------------------------------------------------------------===// |
83 | |
84 | void arith::CeilDivUIOp::inferResultRanges( |
85 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { |
86 | setResultRange(getResult(), inferCeilDivU(argRanges)); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // CeilDivSIOp |
91 | //===----------------------------------------------------------------------===// |
92 | |
93 | void arith::CeilDivSIOp::inferResultRanges( |
94 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { |
95 | setResultRange(getResult(), inferCeilDivS(argRanges)); |
96 | } |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // FloorDivSIOp |
100 | //===----------------------------------------------------------------------===// |
101 | |
102 | void arith::FloorDivSIOp::inferResultRanges( |
103 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { |
104 | return setResultRange(getResult(), inferFloorDivS(argRanges)); |
105 | } |
106 | |
107 | //===----------------------------------------------------------------------===// |
108 | // RemUIOp |
109 | //===----------------------------------------------------------------------===// |
110 | |
111 | void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
112 | SetIntRangeFn setResultRange) { |
113 | setResultRange(getResult(), inferRemU(argRanges)); |
114 | } |
115 | |
116 | //===----------------------------------------------------------------------===// |
117 | // RemSIOp |
118 | //===----------------------------------------------------------------------===// |
119 | |
120 | void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
121 | SetIntRangeFn setResultRange) { |
122 | setResultRange(getResult(), inferRemS(argRanges)); |
123 | } |
124 | |
125 | //===----------------------------------------------------------------------===// |
126 | // AndIOp |
127 | //===----------------------------------------------------------------------===// |
128 | |
129 | void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
130 | SetIntRangeFn setResultRange) { |
131 | setResultRange(getResult(), inferAnd(argRanges)); |
132 | } |
133 | |
134 | //===----------------------------------------------------------------------===// |
135 | // OrIOp |
136 | //===----------------------------------------------------------------------===// |
137 | |
138 | void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
139 | SetIntRangeFn setResultRange) { |
140 | setResultRange(getResult(), inferOr(argRanges)); |
141 | } |
142 | |
143 | //===----------------------------------------------------------------------===// |
144 | // XOrIOp |
145 | //===----------------------------------------------------------------------===// |
146 | |
147 | void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
148 | SetIntRangeFn setResultRange) { |
149 | setResultRange(getResult(), inferXor(argRanges)); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // MaxSIOp |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
157 | SetIntRangeFn setResultRange) { |
158 | setResultRange(getResult(), inferMaxS(argRanges)); |
159 | } |
160 | |
161 | //===----------------------------------------------------------------------===// |
162 | // MaxUIOp |
163 | //===----------------------------------------------------------------------===// |
164 | |
165 | void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
166 | SetIntRangeFn setResultRange) { |
167 | setResultRange(getResult(), inferMaxU(argRanges)); |
168 | } |
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // MinSIOp |
172 | //===----------------------------------------------------------------------===// |
173 | |
174 | void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
175 | SetIntRangeFn setResultRange) { |
176 | setResultRange(getResult(), inferMinS(argRanges)); |
177 | } |
178 | |
179 | //===----------------------------------------------------------------------===// |
180 | // MinUIOp |
181 | //===----------------------------------------------------------------------===// |
182 | |
183 | void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
184 | SetIntRangeFn setResultRange) { |
185 | setResultRange(getResult(), inferMinU(argRanges)); |
186 | } |
187 | |
188 | //===----------------------------------------------------------------------===// |
189 | // ExtUIOp |
190 | //===----------------------------------------------------------------------===// |
191 | |
192 | void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
193 | SetIntRangeFn setResultRange) { |
194 | unsigned destWidth = |
195 | ConstantIntRanges::getStorageBitwidth(getResult().getType()); |
196 | setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); |
197 | } |
198 | |
199 | //===----------------------------------------------------------------------===// |
200 | // ExtSIOp |
201 | //===----------------------------------------------------------------------===// |
202 | |
203 | void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
204 | SetIntRangeFn setResultRange) { |
205 | unsigned destWidth = |
206 | ConstantIntRanges::getStorageBitwidth(getResult().getType()); |
207 | setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); |
208 | } |
209 | |
210 | //===----------------------------------------------------------------------===// |
211 | // TruncIOp |
212 | //===----------------------------------------------------------------------===// |
213 | |
214 | void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
215 | SetIntRangeFn setResultRange) { |
216 | unsigned destWidth = |
217 | ConstantIntRanges::getStorageBitwidth(getResult().getType()); |
218 | setResultRange(getResult(), truncRange(argRanges[0], destWidth)); |
219 | } |
220 | |
221 | //===----------------------------------------------------------------------===// |
222 | // IndexCastOp |
223 | //===----------------------------------------------------------------------===// |
224 | |
225 | void arith::IndexCastOp::inferResultRanges( |
226 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { |
227 | Type sourceType = getOperand().getType(); |
228 | Type destType = getResult().getType(); |
229 | unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); |
230 | unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); |
231 | |
232 | if (srcWidth < destWidth) |
233 | setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); |
234 | else if (srcWidth > destWidth) |
235 | setResultRange(getResult(), truncRange(argRanges[0], destWidth)); |
236 | else |
237 | setResultRange(getResult(), argRanges[0]); |
238 | } |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | // IndexCastUIOp |
242 | //===----------------------------------------------------------------------===// |
243 | |
244 | void arith::IndexCastUIOp::inferResultRanges( |
245 | ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { |
246 | Type sourceType = getOperand().getType(); |
247 | Type destType = getResult().getType(); |
248 | unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); |
249 | unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); |
250 | |
251 | if (srcWidth < destWidth) |
252 | setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); |
253 | else if (srcWidth > destWidth) |
254 | setResultRange(getResult(), truncRange(argRanges[0], destWidth)); |
255 | else |
256 | setResultRange(getResult(), argRanges[0]); |
257 | } |
258 | |
259 | //===----------------------------------------------------------------------===// |
260 | // CmpIOp |
261 | //===----------------------------------------------------------------------===// |
262 | |
263 | void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
264 | SetIntRangeFn setResultRange) { |
265 | arith::CmpIPredicate arithPred = getPredicate(); |
266 | intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred); |
267 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
268 | |
269 | APInt min = APInt::getZero(1); |
270 | APInt max = APInt::getAllOnes(1); |
271 | |
272 | std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs); |
273 | if (truthValue.has_value() && *truthValue) |
274 | min = max; |
275 | else if (truthValue.has_value() && !(*truthValue)) |
276 | max = min; |
277 | |
278 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); |
279 | } |
280 | |
281 | //===----------------------------------------------------------------------===// |
282 | // SelectOp |
283 | //===----------------------------------------------------------------------===// |
284 | |
285 | void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
286 | SetIntRangeFn setResultRange) { |
287 | std::optional<APInt> mbCondVal = argRanges[0].getConstantValue(); |
288 | |
289 | if (mbCondVal) { |
290 | if (mbCondVal->isZero()) |
291 | setResultRange(getResult(), argRanges[2]); |
292 | else |
293 | setResultRange(getResult(), argRanges[1]); |
294 | return; |
295 | } |
296 | setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2])); |
297 | } |
298 | |
299 | //===----------------------------------------------------------------------===// |
300 | // ShLIOp |
301 | //===----------------------------------------------------------------------===// |
302 | |
303 | void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
304 | SetIntRangeFn setResultRange) { |
305 | setResultRange(getResult(), inferShl(argRanges)); |
306 | } |
307 | |
308 | //===----------------------------------------------------------------------===// |
309 | // ShRUIOp |
310 | //===----------------------------------------------------------------------===// |
311 | |
312 | void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
313 | SetIntRangeFn setResultRange) { |
314 | setResultRange(getResult(), inferShrU(argRanges)); |
315 | } |
316 | |
317 | //===----------------------------------------------------------------------===// |
318 | // ShRSIOp |
319 | //===----------------------------------------------------------------------===// |
320 | |
321 | void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
322 | SetIntRangeFn setResultRange) { |
323 | setResultRange(getResult(), inferShrS(argRanges)); |
324 | } |
325 | |