1 | //===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/Presburger/SlowMPInt.h" |
10 | #include "mlir/Support/LLVM.h" |
11 | #include "llvm/ADT/APInt.h" |
12 | #include "llvm/ADT/Hashing.h" |
13 | #include "llvm/ADT/STLFunctionalExtras.h" |
14 | #include "llvm/Support/raw_ostream.h" |
15 | #include <algorithm> |
16 | #include <cassert> |
17 | #include <cstdint> |
18 | #include <functional> |
19 | |
20 | using namespace mlir; |
21 | using namespace presburger; |
22 | using namespace detail; |
23 | |
24 | SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {} |
25 | SlowMPInt::SlowMPInt() : SlowMPInt(0) {} |
26 | SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {} |
27 | SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); } |
28 | SlowMPInt::operator int64_t() const { return val.getSExtValue(); } |
29 | |
30 | llvm::hash_code detail::hash_value(const SlowMPInt &x) { |
31 | return hash_value(Arg: x.val); |
32 | } |
33 | |
34 | /// --------------------------------------------------------------------------- |
35 | /// Printing. |
36 | /// --------------------------------------------------------------------------- |
37 | void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; } |
38 | |
39 | void SlowMPInt::dump() const { print(os&: llvm::errs()); } |
40 | |
41 | llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os, |
42 | const SlowMPInt &x) { |
43 | x.print(os); |
44 | return os; |
45 | } |
46 | |
47 | /// --------------------------------------------------------------------------- |
48 | /// Convenience operator overloads for int64_t. |
49 | /// --------------------------------------------------------------------------- |
50 | SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) { |
51 | return a += SlowMPInt(b); |
52 | } |
53 | SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) { |
54 | return a -= SlowMPInt(b); |
55 | } |
56 | SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) { |
57 | return a *= SlowMPInt(b); |
58 | } |
59 | SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) { |
60 | return a /= SlowMPInt(b); |
61 | } |
62 | SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) { |
63 | return a %= SlowMPInt(b); |
64 | } |
65 | |
66 | bool detail::operator==(const SlowMPInt &a, int64_t b) { |
67 | return a == SlowMPInt(b); |
68 | } |
69 | bool detail::operator!=(const SlowMPInt &a, int64_t b) { |
70 | return a != SlowMPInt(b); |
71 | } |
72 | bool detail::operator>(const SlowMPInt &a, int64_t b) { |
73 | return a > SlowMPInt(b); |
74 | } |
75 | bool detail::operator<(const SlowMPInt &a, int64_t b) { |
76 | return a < SlowMPInt(b); |
77 | } |
78 | bool detail::operator<=(const SlowMPInt &a, int64_t b) { |
79 | return a <= SlowMPInt(b); |
80 | } |
81 | bool detail::operator>=(const SlowMPInt &a, int64_t b) { |
82 | return a >= SlowMPInt(b); |
83 | } |
84 | SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) { |
85 | return a + SlowMPInt(b); |
86 | } |
87 | SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) { |
88 | return a - SlowMPInt(b); |
89 | } |
90 | SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) { |
91 | return a * SlowMPInt(b); |
92 | } |
93 | SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) { |
94 | return a / SlowMPInt(b); |
95 | } |
96 | SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) { |
97 | return a % SlowMPInt(b); |
98 | } |
99 | |
100 | bool detail::operator==(int64_t a, const SlowMPInt &b) { |
101 | return SlowMPInt(a) == b; |
102 | } |
103 | bool detail::operator!=(int64_t a, const SlowMPInt &b) { |
104 | return SlowMPInt(a) != b; |
105 | } |
106 | bool detail::operator>(int64_t a, const SlowMPInt &b) { |
107 | return SlowMPInt(a) > b; |
108 | } |
109 | bool detail::operator<(int64_t a, const SlowMPInt &b) { |
110 | return SlowMPInt(a) < b; |
111 | } |
112 | bool detail::operator<=(int64_t a, const SlowMPInt &b) { |
113 | return SlowMPInt(a) <= b; |
114 | } |
115 | bool detail::operator>=(int64_t a, const SlowMPInt &b) { |
116 | return SlowMPInt(a) >= b; |
117 | } |
118 | SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) { |
119 | return SlowMPInt(a) + b; |
120 | } |
121 | SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) { |
122 | return SlowMPInt(a) - b; |
123 | } |
124 | SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) { |
125 | return SlowMPInt(a) * b; |
126 | } |
127 | SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) { |
128 | return SlowMPInt(a) / b; |
129 | } |
130 | SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) { |
131 | return SlowMPInt(a) % b; |
132 | } |
133 | |
134 | static unsigned getMaxWidth(const APInt &a, const APInt &b) { |
135 | return std::max(a: a.getBitWidth(), b: b.getBitWidth()); |
136 | } |
137 | |
138 | /// --------------------------------------------------------------------------- |
139 | /// Comparison operators. |
140 | /// --------------------------------------------------------------------------- |
141 | |
142 | // TODO: consider instead making APInt::compare available and using that. |
143 | bool SlowMPInt::operator==(const SlowMPInt &o) const { |
144 | unsigned width = getMaxWidth(a: val, b: o.val); |
145 | return val.sext(width) == o.val.sext(width); |
146 | } |
147 | bool SlowMPInt::operator!=(const SlowMPInt &o) const { |
148 | unsigned width = getMaxWidth(a: val, b: o.val); |
149 | return val.sext(width) != o.val.sext(width); |
150 | } |
151 | bool SlowMPInt::operator>(const SlowMPInt &o) const { |
152 | unsigned width = getMaxWidth(a: val, b: o.val); |
153 | return val.sext(width).sgt(RHS: o.val.sext(width)); |
154 | } |
155 | bool SlowMPInt::operator<(const SlowMPInt &o) const { |
156 | unsigned width = getMaxWidth(a: val, b: o.val); |
157 | return val.sext(width).slt(RHS: o.val.sext(width)); |
158 | } |
159 | bool SlowMPInt::operator<=(const SlowMPInt &o) const { |
160 | unsigned width = getMaxWidth(a: val, b: o.val); |
161 | return val.sext(width).sle(RHS: o.val.sext(width)); |
162 | } |
163 | bool SlowMPInt::operator>=(const SlowMPInt &o) const { |
164 | unsigned width = getMaxWidth(a: val, b: o.val); |
165 | return val.sext(width).sge(RHS: o.val.sext(width)); |
166 | } |
167 | |
168 | /// --------------------------------------------------------------------------- |
169 | /// Arithmetic operators. |
170 | /// --------------------------------------------------------------------------- |
171 | |
172 | /// Bring a and b to have the same width and then call op(a, b, overflow). |
173 | /// If the overflow bit becomes set, resize a and b to double the width and |
174 | /// call op(a, b, overflow), returning its result. The operation with double |
175 | /// widths should not also overflow. |
176 | APInt runOpWithExpandOnOverflow( |
177 | const APInt &a, const APInt &b, |
178 | llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)> |
179 | op) { |
180 | bool overflow; |
181 | unsigned width = getMaxWidth(a, b); |
182 | APInt ret = op(a.sext(width), b.sext(width), overflow); |
183 | if (!overflow) |
184 | return ret; |
185 | |
186 | width *= 2; |
187 | ret = op(a.sext(width), b.sext(width), overflow); |
188 | assert(!overflow && "double width should be sufficient to avoid overflow!" ); |
189 | return ret; |
190 | } |
191 | |
192 | SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const { |
193 | return SlowMPInt( |
194 | runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::sadd_ov))); |
195 | } |
196 | SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const { |
197 | return SlowMPInt( |
198 | runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::ssub_ov))); |
199 | } |
200 | SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const { |
201 | return SlowMPInt( |
202 | runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::smul_ov))); |
203 | } |
204 | SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const { |
205 | return SlowMPInt( |
206 | runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::sdiv_ov))); |
207 | } |
208 | SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; } |
209 | SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { |
210 | if (rhs == -1) |
211 | return -lhs; |
212 | unsigned width = getMaxWidth(a: lhs.val, b: rhs.val); |
213 | return SlowMPInt(llvm::APIntOps::RoundingSDiv( |
214 | A: lhs.val.sext(width), B: rhs.val.sext(width), RM: APInt::Rounding::UP)); |
215 | } |
216 | SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { |
217 | if (rhs == -1) |
218 | return -lhs; |
219 | unsigned width = getMaxWidth(a: lhs.val, b: rhs.val); |
220 | return SlowMPInt(llvm::APIntOps::RoundingSDiv( |
221 | A: lhs.val.sext(width), B: rhs.val.sext(width), RM: APInt::Rounding::DOWN)); |
222 | } |
223 | // The RHS is always expected to be positive, and the result |
224 | /// is always non-negative. |
225 | SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) { |
226 | assert(rhs >= 1 && "mod is only supported for positive divisors!" ); |
227 | return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs; |
228 | } |
229 | |
230 | SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) { |
231 | assert(a >= 0 && b >= 0 && "operands must be non-negative!" ); |
232 | unsigned width = getMaxWidth(a: a.val, b: b.val); |
233 | return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(A: a.val.sext(width), |
234 | B: b.val.sext(width))); |
235 | } |
236 | |
237 | /// Returns the least common multiple of 'a' and 'b'. |
238 | SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) { |
239 | SlowMPInt x = abs(x: a); |
240 | SlowMPInt y = abs(x: b); |
241 | return (x * y) / gcd(a: x, b: y); |
242 | } |
243 | |
244 | /// This operation cannot overflow. |
245 | SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const { |
246 | unsigned width = std::max(a: val.getBitWidth(), b: o.val.getBitWidth()); |
247 | return SlowMPInt(val.sext(width).srem(RHS: o.val.sext(width))); |
248 | } |
249 | |
250 | SlowMPInt SlowMPInt::operator-() const { |
251 | if (val.isMinSignedValue()) { |
252 | /// Overflow only occurs when the value is the minimum possible value. |
253 | APInt ret = val.sext(width: 2 * val.getBitWidth()); |
254 | return SlowMPInt(-ret); |
255 | } |
256 | return SlowMPInt(-val); |
257 | } |
258 | |
259 | /// --------------------------------------------------------------------------- |
260 | /// Assignment operators, preincrement, predecrement. |
261 | /// --------------------------------------------------------------------------- |
262 | SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) { |
263 | *this = *this + o; |
264 | return *this; |
265 | } |
266 | SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) { |
267 | *this = *this - o; |
268 | return *this; |
269 | } |
270 | SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) { |
271 | *this = *this * o; |
272 | return *this; |
273 | } |
274 | SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) { |
275 | *this = *this / o; |
276 | return *this; |
277 | } |
278 | SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) { |
279 | *this = *this % o; |
280 | return *this; |
281 | } |
282 | SlowMPInt &SlowMPInt::operator++() { |
283 | *this += 1; |
284 | return *this; |
285 | } |
286 | |
287 | SlowMPInt &SlowMPInt::operator--() { |
288 | *this -= 1; |
289 | return *this; |
290 | } |
291 | |