1 | //===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// A collection of "safe" arithmetic helper functions.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
14 | #define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
15 | |
16 | #include <cassert> |
17 | #include <cinttypes> |
18 | #include <limits> |
19 | #include <type_traits> |
20 | |
21 | namespace mlir { |
22 | namespace sparse_tensor { |
23 | namespace detail { |
24 | |
25 | //===----------------------------------------------------------------------===// |
26 | // |
27 | // Safe comparison functions. |
28 | // |
29 | // Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which |
30 | // are careful to ensure that negatives are always considered strictly |
31 | // less than non-negatives regardless of the signedness of the types of |
32 | // the two arguments. They are "safe" in that they guarantee to *always* |
33 | // give an output and that that output is correct; in particular this means |
34 | // they never use assertions or other mechanisms for "returning an error". |
35 | // |
// These functions are C++17-compatible backports of the safe integer
// comparison functions added in C++20 (`std::cmp_equal`, `std::cmp_less`,
// etc.), and the implementations are based on the possible implementations
// shown on cppreference: <https://en.cppreference.com/w/cpp/utility/intcmp>.
40 | // |
41 | //===----------------------------------------------------------------------===// |
42 | |
/// Returns whether `t == u`, where a negative value never compares equal to
/// a non-negative one regardless of the signedness of `T` and `U`
/// (C++17 backport of C++20 `std::cmp_equal`).
///
/// Both `T` and `U` must be integral types other than `bool`. Enumeration
/// types are rejected: `std::is_signed_v` is `false` for them even when the
/// underlying type is signed, so accepting them would silently misclassify
/// negative enumerators as large unsigned values.
template <typename T, typename U>
constexpr bool safelyEQ(T t, U u) noexcept {
  static_assert(std::is_integral_v<T> && !std::is_same_v<T, bool>,
                "safelyEQ requires a non-bool integral type for `t`");
  static_assert(std::is_integral_v<U> && !std::is_same_v<U, bool>,
                "safelyEQ requires a non-bool integral type for `u`");
  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>) {
    // Same signedness: the usual arithmetic conversions preserve values.
    return t == u;
  } else if constexpr (std::is_signed_v<T>) {
    // A negative `t` can never equal the unsigned `u`; a non-negative `t`
    // is exactly representable in its unsigned counterpart.
    return t >= 0 && static_cast<std::make_unsigned_t<T>>(t) == u;
  } else {
    // Mirror image of the branch above, with `u` as the signed operand.
    return u >= 0 && t == static_cast<std::make_unsigned_t<U>>(u);
  }
}
54 | |
/// Safe `t != u`: the logical negation of `safelyEQ`, so a negative value is
/// always unequal to a non-negative one whatever the operand signedness.
template <typename T, typename U>
constexpr bool safelyNE(T t, U u) noexcept {
  const bool areEqual = safelyEQ(t, u);
  return !areEqual;
}
59 | |
/// Returns whether `t < u`, where every negative value is strictly less than
/// every non-negative value regardless of the signedness of `T` and `U`
/// (C++17 backport of C++20 `std::cmp_less`).
///
/// Both `T` and `U` must be integral types other than `bool`. Enumeration
/// types are rejected: `std::is_signed_v` is `false` for them even when the
/// underlying type is signed, so accepting them would silently order
/// negative enumerators above non-negative values.
template <typename T, typename U>
constexpr bool safelyLT(T t, U u) noexcept {
  static_assert(std::is_integral_v<T> && !std::is_same_v<T, bool>,
                "safelyLT requires a non-bool integral type for `t`");
  static_assert(std::is_integral_v<U> && !std::is_same_v<U, bool>,
                "safelyLT requires a non-bool integral type for `u`");
  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>) {
    // Same signedness: the usual arithmetic conversions preserve values.
    return t < u;
  } else if constexpr (std::is_signed_v<T>) {
    // A negative `t` is below any unsigned `u`; otherwise compare in the
    // unsigned domain, where the non-negative `t` is exactly representable.
    return t < 0 || static_cast<std::make_unsigned_t<T>>(t) < u;
  } else {
    // An unsigned `t` is never below a negative `u`.
    return u >= 0 && t < static_cast<std::make_unsigned_t<U>>(u);
  }
}
71 | |
/// Safe `t > u`, using the same signedness-aware ordering as `safelyLT`.
template <typename T, typename U>
constexpr bool safelyGT(T t, U u) noexcept {
  // `t > u` holds exactly when `u < t`, so delegate with swapped operands.
  const bool uBelowT = safelyLT(u, t);
  return uBelowT;
}
76 | |
/// Safe `t <= u`: true unless `t` is strictly greater than `u`, i.e. unless
/// `u` orders strictly before `t` under the signedness-aware comparison.
template <typename T, typename U>
constexpr bool safelyLE(T t, U u) noexcept {
  if (safelyLT(u, t))
    return false;
  return true;
}
81 | |
/// Safe `t >= u`: the complement of the signedness-aware `t < u`.
template <typename T, typename U>
constexpr bool safelyGE(T t, U u) noexcept {
  const bool tBelowU = safelyLT(t, u);
  return !tBelowU;
}
86 | |
87 | //===----------------------------------------------------------------------===// |
88 | // |
89 | // Overflow checking functions. |
90 | // |
91 | // These functions use assertions to ensure correctness with respect to |
92 | // overflow/underflow. Unlike the "safe" functions above, these "checked" |
93 | // functions only guarantee that *if* they return an answer then that answer |
94 | // is correct. When assertions are enabled, they do their best to remain |
95 | // as fast as possible (since MLIR keeps assertions enabled by default, |
96 | // even for optimized builds). When assertions are disabled, they use the |
97 | // standard unchecked implementations. |
98 | // |
99 | //===----------------------------------------------------------------------===// |
100 | |
/// Converts `x` to `To` like `static_cast<To>`, but asserts (when assertions
/// are enabled) that the value of `x` lies within the range of `To`.
/// Each bound is only checked at runtime when the ranges of `From` and `To`
/// do not already rule out underflow/overflow at compile time, so casts that
/// are statically known to be safe carry no runtime assertion.
template <typename To, typename From>
[[nodiscard]] inline To checkOverflowCast(From x) {
  using ToLimits = std::numeric_limits<To>;
  using FromLimits = std::numeric_limits<From>;
  // Underflow is only possible when `From` can represent values below the
  // minimum of `To` (e.g. casting from a signed type to an unsigned one).
  if constexpr (safelyLT(FromLimits::min(), ToLimits::min()))
    assert(safelyGE(x, ToLimits::min()) && "cast would underflow");
  // Overflow is only possible when `From` can represent values above the
  // maximum of `To`.
  if constexpr (safelyGT(FromLimits::max(), ToLimits::max()))
    assert(safelyLE(x, ToLimits::max()) && "cast would overflow");
  return static_cast<To>(x);
}
120 | |
/// A version of `operator*` on `uint64_t` which guards against overflows
/// (when assertions are enabled).
///
/// \returns `lhs * rhs`. When assertions are disabled, an overflowing
/// product wraps modulo 2^64 exactly like plain unsigned multiplication.
/// Marked `[[nodiscard]]` (like `checkOverflowCast`), since discarding the
/// product of a pure multiplication is always a mistake.
[[nodiscard]] inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  // For a non-zero `lhs`, `lhs * rhs` fits in 64 bits iff
  // `rhs <= max / lhs` (integer division). The `lhs == 0` test covers the
  // trivially-safe case and also prevents a division by zero.
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  return lhs * rhs;
}
128 | |
129 | } // namespace detail |
130 | } // namespace sparse_tensor |
131 | } // namespace mlir |
132 | |
133 | #endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
134 | |