1//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/Presburger/SlowMPInt.h"
10#include "mlir/Support/LLVM.h"
11#include "llvm/ADT/APInt.h"
12#include "llvm/ADT/Hashing.h"
13#include "llvm/ADT/STLFunctionalExtras.h"
14#include "llvm/Support/raw_ostream.h"
15#include <algorithm>
16#include <cassert>
17#include <cstdint>
18#include <functional>
19
20using namespace mlir;
21using namespace presburger;
22using namespace detail;
23
24SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
25SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
26SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
27SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
28SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
29
30llvm::hash_code detail::hash_value(const SlowMPInt &x) {
31 return hash_value(Arg: x.val);
32}
33
34/// ---------------------------------------------------------------------------
35/// Printing.
36/// ---------------------------------------------------------------------------
37void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; }
38
39void SlowMPInt::dump() const { print(os&: llvm::errs()); }
40
41llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
42 const SlowMPInt &x) {
43 x.print(os);
44 return os;
45}
46
47/// ---------------------------------------------------------------------------
48/// Convenience operator overloads for int64_t.
49/// ---------------------------------------------------------------------------
50SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
51 return a += SlowMPInt(b);
52}
53SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
54 return a -= SlowMPInt(b);
55}
56SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
57 return a *= SlowMPInt(b);
58}
59SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
60 return a /= SlowMPInt(b);
61}
62SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
63 return a %= SlowMPInt(b);
64}
65
66bool detail::operator==(const SlowMPInt &a, int64_t b) {
67 return a == SlowMPInt(b);
68}
69bool detail::operator!=(const SlowMPInt &a, int64_t b) {
70 return a != SlowMPInt(b);
71}
72bool detail::operator>(const SlowMPInt &a, int64_t b) {
73 return a > SlowMPInt(b);
74}
75bool detail::operator<(const SlowMPInt &a, int64_t b) {
76 return a < SlowMPInt(b);
77}
78bool detail::operator<=(const SlowMPInt &a, int64_t b) {
79 return a <= SlowMPInt(b);
80}
81bool detail::operator>=(const SlowMPInt &a, int64_t b) {
82 return a >= SlowMPInt(b);
83}
84SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
85 return a + SlowMPInt(b);
86}
87SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
88 return a - SlowMPInt(b);
89}
90SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
91 return a * SlowMPInt(b);
92}
93SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
94 return a / SlowMPInt(b);
95}
96SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
97 return a % SlowMPInt(b);
98}
99
100bool detail::operator==(int64_t a, const SlowMPInt &b) {
101 return SlowMPInt(a) == b;
102}
103bool detail::operator!=(int64_t a, const SlowMPInt &b) {
104 return SlowMPInt(a) != b;
105}
106bool detail::operator>(int64_t a, const SlowMPInt &b) {
107 return SlowMPInt(a) > b;
108}
109bool detail::operator<(int64_t a, const SlowMPInt &b) {
110 return SlowMPInt(a) < b;
111}
112bool detail::operator<=(int64_t a, const SlowMPInt &b) {
113 return SlowMPInt(a) <= b;
114}
115bool detail::operator>=(int64_t a, const SlowMPInt &b) {
116 return SlowMPInt(a) >= b;
117}
118SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
119 return SlowMPInt(a) + b;
120}
121SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
122 return SlowMPInt(a) - b;
123}
124SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
125 return SlowMPInt(a) * b;
126}
127SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
128 return SlowMPInt(a) / b;
129}
130SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
131 return SlowMPInt(a) % b;
132}
133
134static unsigned getMaxWidth(const APInt &a, const APInt &b) {
135 return std::max(a: a.getBitWidth(), b: b.getBitWidth());
136}
137
138/// ---------------------------------------------------------------------------
139/// Comparison operators.
140/// ---------------------------------------------------------------------------
141
142// TODO: consider instead making APInt::compare available and using that.
143bool SlowMPInt::operator==(const SlowMPInt &o) const {
144 unsigned width = getMaxWidth(a: val, b: o.val);
145 return val.sext(width) == o.val.sext(width);
146}
147bool SlowMPInt::operator!=(const SlowMPInt &o) const {
148 unsigned width = getMaxWidth(a: val, b: o.val);
149 return val.sext(width) != o.val.sext(width);
150}
151bool SlowMPInt::operator>(const SlowMPInt &o) const {
152 unsigned width = getMaxWidth(a: val, b: o.val);
153 return val.sext(width).sgt(RHS: o.val.sext(width));
154}
155bool SlowMPInt::operator<(const SlowMPInt &o) const {
156 unsigned width = getMaxWidth(a: val, b: o.val);
157 return val.sext(width).slt(RHS: o.val.sext(width));
158}
159bool SlowMPInt::operator<=(const SlowMPInt &o) const {
160 unsigned width = getMaxWidth(a: val, b: o.val);
161 return val.sext(width).sle(RHS: o.val.sext(width));
162}
163bool SlowMPInt::operator>=(const SlowMPInt &o) const {
164 unsigned width = getMaxWidth(a: val, b: o.val);
165 return val.sext(width).sge(RHS: o.val.sext(width));
166}
167
168/// ---------------------------------------------------------------------------
169/// Arithmetic operators.
170/// ---------------------------------------------------------------------------
171
172/// Bring a and b to have the same width and then call op(a, b, overflow).
173/// If the overflow bit becomes set, resize a and b to double the width and
174/// call op(a, b, overflow), returning its result. The operation with double
175/// widths should not also overflow.
176APInt runOpWithExpandOnOverflow(
177 const APInt &a, const APInt &b,
178 llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
179 op) {
180 bool overflow;
181 unsigned width = getMaxWidth(a, b);
182 APInt ret = op(a.sext(width), b.sext(width), overflow);
183 if (!overflow)
184 return ret;
185
186 width *= 2;
187 ret = op(a.sext(width), b.sext(width), overflow);
188 assert(!overflow && "double width should be sufficient to avoid overflow!");
189 return ret;
190}
191
192SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
193 return SlowMPInt(
194 runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::sadd_ov)));
195}
196SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
197 return SlowMPInt(
198 runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::ssub_ov)));
199}
200SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
201 return SlowMPInt(
202 runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::smul_ov)));
203}
204SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
205 return SlowMPInt(
206 runOpWithExpandOnOverflow(a: val, b: o.val, op: std::mem_fn(pm: &APInt::sdiv_ov)));
207}
208SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; }
209SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
210 if (rhs == -1)
211 return -lhs;
212 unsigned width = getMaxWidth(a: lhs.val, b: rhs.val);
213 return SlowMPInt(llvm::APIntOps::RoundingSDiv(
214 A: lhs.val.sext(width), B: rhs.val.sext(width), RM: APInt::Rounding::UP));
215}
216SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
217 if (rhs == -1)
218 return -lhs;
219 unsigned width = getMaxWidth(a: lhs.val, b: rhs.val);
220 return SlowMPInt(llvm::APIntOps::RoundingSDiv(
221 A: lhs.val.sext(width), B: rhs.val.sext(width), RM: APInt::Rounding::DOWN));
222}
223// The RHS is always expected to be positive, and the result
224/// is always non-negative.
225SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
226 assert(rhs >= 1 && "mod is only supported for positive divisors!");
227 return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
228}
229
230SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
231 assert(a >= 0 && b >= 0 && "operands must be non-negative!");
232 unsigned width = getMaxWidth(a: a.val, b: b.val);
233 return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(A: a.val.sext(width),
234 B: b.val.sext(width)));
235}
236
237/// Returns the least common multiple of 'a' and 'b'.
238SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
239 SlowMPInt x = abs(x: a);
240 SlowMPInt y = abs(x: b);
241 return (x * y) / gcd(a: x, b: y);
242}
243
244/// This operation cannot overflow.
245SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
246 unsigned width = std::max(a: val.getBitWidth(), b: o.val.getBitWidth());
247 return SlowMPInt(val.sext(width).srem(RHS: o.val.sext(width)));
248}
249
250SlowMPInt SlowMPInt::operator-() const {
251 if (val.isMinSignedValue()) {
252 /// Overflow only occurs when the value is the minimum possible value.
253 APInt ret = val.sext(width: 2 * val.getBitWidth());
254 return SlowMPInt(-ret);
255 }
256 return SlowMPInt(-val);
257}
258
259/// ---------------------------------------------------------------------------
260/// Assignment operators, preincrement, predecrement.
261/// ---------------------------------------------------------------------------
262SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
263 *this = *this + o;
264 return *this;
265}
266SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
267 *this = *this - o;
268 return *this;
269}
270SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
271 *this = *this * o;
272 return *this;
273}
274SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
275 *this = *this / o;
276 return *this;
277}
278SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
279 *this = *this % o;
280 return *this;
281}
282SlowMPInt &SlowMPInt::operator++() {
283 *this += 1;
284 return *this;
285}
286
287SlowMPInt &SlowMPInt::operator--() {
288 *this -= 1;
289 return *this;
290}
291

source code of mlir/lib/Analysis/Presburger/SlowMPInt.cpp