1//===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// A collection of "safe" comparison and "checked" arithmetic helper functions.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
14#define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
15
16#include <cassert>
17#include <cinttypes>
18#include <limits>
19#include <type_traits>
20
21namespace mlir {
22namespace sparse_tensor {
23namespace detail {
24
25//===----------------------------------------------------------------------===//
26//
27// Safe comparison functions.
28//
29// Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which
30// are careful to ensure that negatives are always considered strictly
31// less than non-negatives regardless of the signedness of the types of
32// the two arguments. They are "safe" in that they guarantee to *always*
33// give an output and that that output is correct; in particular this means
34// they never use assertions or other mechanisms for "returning an error".
35//
36// These functions are C++17-compatible backports of the safe comparison
37// functions added in C++20, and the implementations are based on the
38// sample implementations provided by the standard:
39// <https://en.cppreference.com/w/cpp/utility/intcmp>.
40//
41//===----------------------------------------------------------------------===//
42
/// Returns whether `t == u` holds mathematically, i.e. without the usual
/// arithmetic conversions reinterpreting a negative signed value as a large
/// unsigned one. A negative value never compares equal to any value of an
/// unsigned type.
template <typename T, typename U>
constexpr bool safelyEQ(T t, U u) noexcept {
  using UnsignedT = std::make_unsigned_t<T>;
  using UnsignedU = std::make_unsigned_t<U>;
  constexpr bool tIsSigned = std::is_signed_v<T>;
  constexpr bool uIsSigned = std::is_signed_v<U>;
  if constexpr (tIsSigned && !uIsSigned) {
    // Only a non-negative `t` can match; it then fits in `UnsignedT`.
    return t >= 0 && static_cast<UnsignedT>(t) == u;
  } else if constexpr (!tIsSigned && uIsSigned) {
    // Mirror image of the previous case.
    return u >= 0 && t == static_cast<UnsignedU>(u);
  } else {
    // Same signedness: the usual arithmetic conversions preserve both values.
    return t == u;
  }
}
54
/// Returns whether `t != u` holds mathematically; this is exactly the
/// negation of `safelyEQ(t, u)`, so mixed-signedness arguments are handled
/// without surprising conversions.
template <typename T, typename U>
constexpr bool safelyNE(T t, U u) noexcept {
  const bool equal = safelyEQ(t, u);
  return !equal;
}
59
/// Returns whether `t < u` holds mathematically. Negative values are always
/// strictly less than non-negative ones, regardless of the signedness of `T`
/// and `U`.
template <typename T, typename U>
constexpr bool safelyLT(T t, U u) noexcept {
  using UnsignedT = std::make_unsigned_t<T>;
  using UnsignedU = std::make_unsigned_t<U>;
  constexpr bool tIsSigned = std::is_signed_v<T>;
  constexpr bool uIsSigned = std::is_signed_v<U>;
  if constexpr (tIsSigned && !uIsSigned) {
    // Any negative `t` is below every unsigned `u`; otherwise `t` fits in
    // `UnsignedT` and can be compared directly.
    return t < 0 || static_cast<UnsignedT>(t) < u;
  } else if constexpr (!tIsSigned && uIsSigned) {
    // No unsigned `t` is below a negative `u`.
    return u >= 0 && t < static_cast<UnsignedU>(u);
  } else {
    // Same signedness: the usual arithmetic conversions preserve both values.
    return t < u;
  }
}
71
/// Returns whether `t > u` holds mathematically. `t > u` is the same relation
/// as `u < t`, so this defers to `safelyLT` with the operands swapped.
template <typename T, typename U>
constexpr bool safelyGT(T t, U u) noexcept {
  const bool uIsBelowT = safelyLT(u, t);
  return uIsBelowT;
}
76
/// Returns whether `t <= u` holds mathematically, which is the case exactly
/// when `u < t` does not hold.
template <typename T, typename U>
constexpr bool safelyLE(T t, U u) noexcept {
  const bool uIsBelowT = safelyLT(u, t);
  return !uIsBelowT;
}
81
/// Returns whether `t >= u` holds mathematically. `t >= u` is the same
/// relation as `u <= t`, so this defers to `safelyLE` with swapped operands.
template <typename T, typename U>
constexpr bool safelyGE(T t, U u) noexcept {
  const bool uIsAtMostT = safelyLE(u, t);
  return uIsAtMostT;
}
86
87//===----------------------------------------------------------------------===//
88//
89// Overflow checking functions.
90//
// These functions use assertions to guard against overflow/underflow.
// Unlike the "safe" functions above, these "checked" functions do not
// always produce an answer: when assertions are enabled, they only
// guarantee that *if* they return an answer then that answer is correct,
// and they try to remain as fast as possible (since MLIR keeps assertions
// enabled by default, even for optimized builds). When assertions are
// disabled, they fall back to the standard unchecked implementations, so
// overflow/underflow is no longer detected.
98//
99//===----------------------------------------------------------------------===//
100
/// Converts `x` to `To` exactly like `static_cast<To>(x)`, but (when
/// assertions are enabled) first asserts that `x` lies within the range of
/// `To`, i.e. that the conversion neither underflows nor overflows.
/// Each bound is only checked at runtime when the ranges of `From` and `To`
/// alone do not already prove at compile time that it cannot be exceeded.
template <typename To, typename From>
[[nodiscard]] inline To checkOverflowCast(From x) {
  using ToLimits = std::numeric_limits<To>;
  using FromLimits = std::numeric_limits<From>;
  constexpr To minTo = ToLimits::min();
  constexpr To maxTo = ToLimits::max();
  // When every `From` value is already >= the smallest `To` value (e.g. when
  // casting from an unsigned type), underflow is impossible.
  constexpr bool lowerBoundHolds = safelyGE(FromLimits::min(), minTo);
  // Likewise, when every `From` value is <= the largest `To` value, overflow
  // is impossible.
  constexpr bool upperBoundHolds = safelyLE(FromLimits::max(), maxTo);
  if constexpr (!lowerBoundHolds)
    assert(safelyGE(x, minTo) && "cast would underflow");
  if constexpr (!upperBoundHolds)
    assert(safelyLE(x, maxTo) && "cast would overflow");
  return static_cast<To>(x);
}
120
/// Multiplies two `uint64_t` values. When assertions are enabled, this
/// asserts that the mathematical product fits in 64 bits; with assertions
/// disabled it is a plain unsigned multiply (which wraps on overflow).
inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  // A zero `lhs` can never overflow; otherwise `lhs * rhs` fits exactly
  // when `rhs <= UINT64_MAX / lhs` (integer division rounds down).
  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
         "Integer overflow");
  const uint64_t product = lhs * rhs;
  return product;
}
128
129} // namespace detail
130} // namespace sparse_tensor
131} // namespace mlir
132
133#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
134

source code of mlir/include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h