1 | //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
10 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
11 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
12 | |
13 | #include "llvm/Support/Debug.h" |
14 | #include <optional> |
15 | |
16 | #define DEBUG_TYPE "int-range-analysis" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::index; |
20 | using namespace mlir::intrange; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Constants |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
27 | SetIntRangeFn setResultRange) { |
28 | const APInt &value = getValue(); |
29 | setResultRange(getResult(), ConstantIntRanges::constant(value)); |
30 | } |
31 | |
32 | void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
33 | SetIntRangeFn setResultRange) { |
34 | bool value = getValue(); |
35 | APInt asInt(/*numBits=*/1, value); |
36 | setResultRange(getResult(), ConstantIntRanges::constant(asInt)); |
37 | } |
38 | |
39 | //===----------------------------------------------------------------------===// |
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (at which
// time we take the 64-bit result).
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
48 | SetIntRangeFn setResultRange) { |
49 | setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); |
50 | } |
51 | |
52 | void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
53 | SetIntRangeFn setResultRange) { |
54 | setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); |
55 | } |
56 | |
57 | void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
58 | SetIntRangeFn setResultRange) { |
59 | setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); |
60 | } |
61 | |
62 | void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
63 | SetIntRangeFn setResultRange) { |
64 | setResultRange(getResult(), |
65 | inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); |
66 | } |
67 | |
68 | void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
69 | SetIntRangeFn setResultRange) { |
70 | setResultRange(getResult(), |
71 | inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); |
72 | } |
73 | |
74 | void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
75 | SetIntRangeFn setResultRange) { |
76 | setResultRange(getResult(), |
77 | inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); |
78 | } |
79 | |
80 | void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
81 | SetIntRangeFn setResultRange) { |
82 | setResultRange(getResult(), |
83 | inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); |
84 | } |
85 | |
86 | void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
87 | SetIntRangeFn setResultRange) { |
88 | return setResultRange( |
89 | getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); |
90 | } |
91 | |
92 | void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
93 | SetIntRangeFn setResultRange) { |
94 | setResultRange(getResult(), |
95 | inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); |
96 | } |
97 | |
98 | void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
99 | SetIntRangeFn setResultRange) { |
100 | setResultRange(getResult(), |
101 | inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); |
102 | } |
103 | |
104 | void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
105 | SetIntRangeFn setResultRange) { |
106 | setResultRange(getResult(), |
107 | inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); |
108 | } |
109 | |
110 | void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
111 | SetIntRangeFn setResultRange) { |
112 | setResultRange(getResult(), |
113 | inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); |
114 | } |
115 | |
116 | void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
117 | SetIntRangeFn setResultRange) { |
118 | setResultRange(getResult(), |
119 | inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); |
120 | } |
121 | |
122 | void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
123 | SetIntRangeFn setResultRange) { |
124 | setResultRange(getResult(), |
125 | inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); |
126 | } |
127 | |
128 | void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
129 | SetIntRangeFn setResultRange) { |
130 | setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); |
131 | } |
132 | |
133 | void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
134 | SetIntRangeFn setResultRange) { |
135 | setResultRange(getResult(), |
136 | inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); |
137 | } |
138 | |
139 | void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
140 | SetIntRangeFn setResultRange) { |
141 | setResultRange(getResult(), |
142 | inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); |
143 | } |
144 | |
145 | void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
146 | SetIntRangeFn setResultRange) { |
147 | setResultRange(getResult(), |
148 | inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); |
149 | } |
150 | |
151 | void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
152 | SetIntRangeFn setResultRange) { |
153 | setResultRange(getResult(), |
154 | inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); |
155 | } |
156 | |
157 | void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
158 | SetIntRangeFn setResultRange) { |
159 | setResultRange(getResult(), |
160 | inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); |
161 | } |
162 | |
163 | //===----------------------------------------------------------------------===// |
164 | // Casts |
165 | //===----------------------------------------------------------------------===// |
166 | |
167 | static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, |
168 | unsigned srcWidth, unsigned destWidth, |
169 | bool isSigned) { |
170 | if (srcWidth < destWidth) |
171 | return isSigned ? extSIRange(range, destWidth) |
172 | : extUIRange(range, destWidth); |
173 | if (srcWidth > destWidth) |
174 | return truncRange(range, destWidth); |
175 | return range; |
176 | } |
177 | |
178 | // When casting to `index`, we will take the union of the possible fixed-width |
179 | // casts. |
180 | static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, |
181 | Type sourceType, Type destType, |
182 | bool isSigned) { |
183 | unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(type: sourceType); |
184 | unsigned destWidth = ConstantIntRanges::getStorageBitwidth(type: destType); |
185 | if (sourceType.isIndex()) |
186 | return makeLikeDest(range, srcWidth, destWidth, isSigned); |
187 | // We are casting to indexs, so use the union of the 32-bit and 64-bit casts |
188 | ConstantIntRanges storageRange = |
189 | makeLikeDest(range, srcWidth, destWidth, isSigned); |
190 | ConstantIntRanges minWidthRange = |
191 | makeLikeDest(range, srcWidth, destWidth: indexMinWidth, isSigned); |
192 | ConstantIntRanges minWidthExt = extRange(range: minWidthRange, destWidth); |
193 | ConstantIntRanges ret = storageRange.rangeUnion(other: minWidthExt); |
194 | return ret; |
195 | } |
196 | |
197 | void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
198 | SetIntRangeFn setResultRange) { |
199 | Type sourceType = getOperand().getType(); |
200 | Type destType = getResult().getType(); |
201 | setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, |
202 | /*isSigned=*/true)); |
203 | } |
204 | |
205 | void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
206 | SetIntRangeFn setResultRange) { |
207 | Type sourceType = getOperand().getType(); |
208 | Type destType = getResult().getType(); |
209 | setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, |
210 | /*isSigned=*/false)); |
211 | } |
212 | |
213 | //===----------------------------------------------------------------------===// |
214 | // CmpOp |
215 | //===----------------------------------------------------------------------===// |
216 | |
217 | void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
218 | SetIntRangeFn setResultRange) { |
219 | index::IndexCmpPredicate indexPred = getPred(); |
220 | intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred); |
221 | const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
222 | |
223 | APInt min = APInt::getZero(1); |
224 | APInt max = APInt::getAllOnes(1); |
225 | |
226 | std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs); |
227 | |
228 | ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), |
229 | rhsTrunc = truncRange(rhs, indexMinWidth); |
230 | std::optional<bool> truthValue32 = |
231 | intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); |
232 | |
233 | if (truthValue64 == truthValue32) { |
234 | if (truthValue64.has_value() && *truthValue64) |
235 | min = max; |
236 | else if (truthValue64.has_value() && !(*truthValue64)) |
237 | max = min; |
238 | } |
239 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); |
240 | } |
241 | |
242 | //===----------------------------------------------------------------------===// |
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
244 | //===----------------------------------------------------------------------===// |
245 | |
246 | void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
247 | SetIntRangeFn setResultRange) { |
248 | unsigned storageWidth = |
249 | ConstantIntRanges::getStorageBitwidth(getResult().getType()); |
250 | APInt min(/*numBits=*/storageWidth, indexMinWidth); |
251 | APInt max(/*numBits=*/storageWidth, indexMaxWidth); |
252 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); |
253 | } |
254 | |