| 1 | //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===// | 
| 2 | // | 
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
| 6 | // | 
| 7 | //===----------------------------------------------------------------------===// | 
| 8 | /// \file | 
| 9 | /// This file provides a collection of function (or more generally, callable) | 
| 10 | /// type erasure utilities supplementing those provided by the standard library | 
| 11 | /// in `<function>`. | 
| 12 | /// | 
| 13 | /// It provides `unique_function`, which works like `std::function` but supports | 
| 14 | /// move-only callable objects and const-qualification. | 
| 15 | /// | 
| 16 | /// Future plans: | 
| 17 | /// - Add a `function` that provides ref-qualified support, which doesn't work | 
| 18 | ///   with `std::function`. | 
| 19 | /// - Provide support for specifying multiple signatures to type erase callable | 
| 20 | ///   objects with an overload set, such as those produced by generic lambdas. | 
| 21 | /// - Expand to include a copyable utility that directly replaces std::function | 
| 22 | ///   but brings the above improvements. | 
| 23 | /// | 
| 24 | /// Note that LLVM's utilities are greatly simplified by not supporting | 
| 25 | /// allocators. | 
| 26 | /// | 
| 27 | /// If the standard library ever begins to provide comparable facilities we can | 
| 28 | /// consider switching to those. | 
| 29 | /// | 
| 30 | //===----------------------------------------------------------------------===// | 
| 31 |  | 
| 32 | #ifndef LLVM_ADT_FUNCTIONEXTRAS_H | 
#define LLVM_ADT_FUNCTIONEXTRAS_H
| 34 |  | 
| 35 | #include "llvm/ADT/PointerIntPair.h" | 
| 36 | #include "llvm/ADT/PointerUnion.h" | 
| 37 | #include "llvm/ADT/STLForwardCompat.h" | 
| 38 | #include "llvm/Support/MemAlloc.h" | 
| 39 | #include "llvm/Support/type_traits.h" | 
| 40 | #include <cstring> | 
| 41 | #include <memory> | 
| 42 | #include <type_traits> | 
| 43 |  | 
| 44 | namespace llvm { | 
| 45 |  | 
/// A move-only, type-erasing callable wrapper in the spirit of std::function.
///
/// Because it never copies the wrapped object, it can store function objects
/// that are themselves move-only (for example a lambda that captures a
/// std::unique_ptr). The wrapper is therefore movable but not copyable.
///
/// The signature may be const-qualified:
/// - `unique_function<int() const>` exposes a const `operator()` and accepts
///   only callables that have a const `operator()`.
/// - `unique_function<int()>` exposes a non-const `operator()` and also accepts
///   callables whose `operator()` is non-const, such as mutable lambdas.
template <typename FunctionT>
class unique_function;
| 57 |  | 
| 58 | namespace detail { | 
| 59 |  | 
/// SFINAE helper that names `void` only when \p T is both trivially
/// move-constructible and trivially destructible; otherwise substitution fails.
template <typename T>
using EnableIfTrivial =
    std::enable_if_t<std::is_trivially_move_constructible_v<T> &&
                     std::is_trivially_destructible_v<T>>;
| 64 | template <typename CallableT, typename ThisT> | 
| 65 | using EnableUnlessSameType = | 
| 66 |     std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>; | 
/// SFINAE helper that names `void` when a \p CallableT can be invoked with
/// \p Params and the result is usable as \p Ret. The result is accepted when
/// any of the following holds:
/// - \p Ret is void (the result is discarded),
/// - the call yields exactly \p Ret,
/// - the call yields \p Ret once a top-level const is added, or
/// - the call's result is convertible to \p Ret.
/// If the call expression itself is ill-formed, substitution fails.
template <typename CallableT, typename Ret, typename... Params>
using EnableIfCallable = std::enable_if_t<std::disjunction_v<
    std::is_void<Ret>,
    std::is_same<decltype(std::declval<CallableT>()(std::declval<Params>()...)),
                 Ret>,
    std::is_same<std::add_const_t<decltype(std::declval<CallableT>()(
                     std::declval<Params>()...))>,
                 Ret>,
    std::is_convertible<decltype(std::declval<CallableT>()(
                            std::declval<Params>()...)),
                        Ret>>>;
| 78 |  | 
| 79 | template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase { | 
| 80 | protected: | 
| 81 |   static constexpr size_t InlineStorageSize = sizeof(void *) * 3; | 
| 82 |  | 
| 83 |   template <typename T, class = void> | 
| 84 |   struct IsSizeLessThanThresholdT : std::false_type {}; | 
| 85 |  | 
| 86 |   template <typename T> | 
| 87 |   struct IsSizeLessThanThresholdT< | 
| 88 |       T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {}; | 
| 89 |  | 
| 90 |   // Provide a type function to map parameters that won't observe extra copies | 
| 91 |   // or moves and which are small enough to likely pass in register to values | 
| 92 |   // and all other types to l-value reference types. We use this to compute the | 
| 93 |   // types used in our erased call utility to minimize copies and moves unless | 
| 94 |   // doing so would force things unnecessarily into memory. | 
| 95 |   // | 
| 96 |   // The heuristic used is related to common ABI register passing conventions. | 
| 97 |   // It doesn't have to be exact though, and in one way it is more strict | 
| 98 |   // because we want to still be able to observe either moves *or* copies. | 
| 99 |   template <typename T> struct AdjustedParamTBase { | 
| 100 |     static_assert(!std::is_reference<T>::value, | 
| 101 |                   "references should be handled by template specialization" ); | 
| 102 |     using type = | 
| 103 |         std::conditional_t<std::is_trivially_copy_constructible<T>::value && | 
| 104 |                                std::is_trivially_move_constructible<T>::value && | 
| 105 |                                IsSizeLessThanThresholdT<T>::value, | 
| 106 |                            T, T &>; | 
| 107 |   }; | 
| 108 |  | 
| 109 |   // This specialization ensures that 'AdjustedParam<V<T>&>' or | 
| 110 |   // 'AdjustedParam<V<T>&&>' does not trigger a compile-time error when 'T' is | 
| 111 |   // an incomplete type and V a templated type. | 
| 112 |   template <typename T> struct AdjustedParamTBase<T &> { using type = T &; }; | 
| 113 |   template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; }; | 
| 114 |  | 
| 115 |   template <typename T> | 
| 116 |   using AdjustedParamT = typename AdjustedParamTBase<T>::type; | 
| 117 |  | 
| 118 |   // The type of the erased function pointer we use as a callback to dispatch to | 
| 119 |   // the stored callable when it is trivial to move and destroy. | 
| 120 |   using CallPtrT = ReturnT (*)(void *CallableAddr, | 
| 121 |                                AdjustedParamT<ParamTs>... Params); | 
| 122 |   using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr); | 
| 123 |   using DestroyPtrT = void (*)(void *CallableAddr); | 
| 124 |  | 
| 125 |   /// A struct to hold a single trivial callback with sufficient alignment for | 
| 126 |   /// our bitpacking. | 
| 127 |   struct alignas(8) TrivialCallback { | 
| 128 |     CallPtrT CallPtr; | 
| 129 |   }; | 
| 130 |  | 
| 131 |   /// A struct we use to aggregate three callbacks when we need full set of | 
| 132 |   /// operations. | 
| 133 |   struct alignas(8) NonTrivialCallbacks { | 
| 134 |     CallPtrT CallPtr; | 
| 135 |     MovePtrT MovePtr; | 
| 136 |     DestroyPtrT DestroyPtr; | 
| 137 |   }; | 
| 138 |  | 
| 139 |   // Create a pointer union between either a pointer to a static trivial call | 
| 140 |   // pointer in a struct or a pointer to a static struct of the call, move, and | 
| 141 |   // destroy pointers. | 
| 142 |   using CallbackPointerUnionT = | 
| 143 |       PointerUnion<TrivialCallback *, NonTrivialCallbacks *>; | 
| 144 |  | 
| 145 |   // The main storage buffer. This will either have a pointer to out-of-line | 
| 146 |   // storage or an inline buffer storing the callable. | 
| 147 |   union StorageUnionT { | 
| 148 |     // For out-of-line storage we keep a pointer to the underlying storage and | 
| 149 |     // the size. This is enough to deallocate the memory. | 
| 150 |     struct OutOfLineStorageT { | 
| 151 |       void *StoragePtr; | 
| 152 |       size_t Size; | 
| 153 |       size_t Alignment; | 
| 154 |     } OutOfLineStorage; | 
| 155 |     static_assert( | 
| 156 |         sizeof(OutOfLineStorageT) <= InlineStorageSize, | 
| 157 |         "Should always use all of the out-of-line storage for inline storage!" ); | 
| 158 |  | 
| 159 |     // For in-line storage, we just provide an aligned character buffer. We | 
| 160 |     // provide three pointers worth of storage here. | 
| 161 |     // This is mutable as an inlined `const unique_function<void() const>` may | 
| 162 |     // still modify its own mutable members. | 
| 163 |     mutable std::aligned_storage_t<InlineStorageSize, alignof(void *)> | 
| 164 |         InlineStorage; | 
| 165 |   } StorageUnion; | 
| 166 |  | 
| 167 |   // A compressed pointer to either our dispatching callback or our table of | 
| 168 |   // dispatching callbacks and the flag for whether the callable itself is | 
| 169 |   // stored inline or not. | 
| 170 |   PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag; | 
| 171 |  | 
| 172 |   bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); } | 
| 173 |  | 
| 174 |   bool isTrivialCallback() const { | 
| 175 |     return isa<TrivialCallback *>(CallbackAndInlineFlag.getPointer()); | 
| 176 |   } | 
| 177 |  | 
| 178 |   CallPtrT getTrivialCallback() const { | 
| 179 |     return cast<TrivialCallback *>(CallbackAndInlineFlag.getPointer())->CallPtr; | 
| 180 |   } | 
| 181 |  | 
| 182 |   NonTrivialCallbacks *getNonTrivialCallbacks() const { | 
| 183 |     return cast<NonTrivialCallbacks *>(CallbackAndInlineFlag.getPointer()); | 
| 184 |   } | 
| 185 |  | 
| 186 |   CallPtrT getCallPtr() const { | 
| 187 |     return isTrivialCallback() ? getTrivialCallback() | 
| 188 |                                : getNonTrivialCallbacks()->CallPtr; | 
| 189 |   } | 
| 190 |  | 
| 191 |   // These three functions are only const in the narrow sense. They return | 
| 192 |   // mutable pointers to function state. | 
| 193 |   // This allows unique_function<T const>::operator() to be const, even if the | 
| 194 |   // underlying functor may be internally mutable. | 
| 195 |   // | 
| 196 |   // const callers must ensure they're only used in const-correct ways. | 
| 197 |   void *getCalleePtr() const { | 
| 198 |     return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage(); | 
| 199 |   } | 
| 200 |   void *getInlineStorage() const { return &StorageUnion.InlineStorage; } | 
| 201 |   void *getOutOfLineStorage() const { | 
| 202 |     return StorageUnion.OutOfLineStorage.StoragePtr; | 
| 203 |   } | 
| 204 |  | 
| 205 |   size_t getOutOfLineStorageSize() const { | 
| 206 |     return StorageUnion.OutOfLineStorage.Size; | 
| 207 |   } | 
| 208 |   size_t getOutOfLineStorageAlignment() const { | 
| 209 |     return StorageUnion.OutOfLineStorage.Alignment; | 
| 210 |   } | 
| 211 |  | 
| 212 |   void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) { | 
| 213 |     StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment}; | 
| 214 |   } | 
| 215 |  | 
| 216 |   template <typename CalledAsT> | 
| 217 |   static ReturnT CallImpl(void *CallableAddr, | 
| 218 |                           AdjustedParamT<ParamTs>... Params) { | 
| 219 |     auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr); | 
| 220 |     return Func(std::forward<ParamTs>(Params)...); | 
| 221 |   } | 
| 222 |  | 
| 223 |   template <typename CallableT> | 
| 224 |   static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept { | 
| 225 |     new (LHSCallableAddr) | 
| 226 |         CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr))); | 
| 227 |   } | 
| 228 |  | 
| 229 |   template <typename CallableT> | 
| 230 |   static void DestroyImpl(void *CallableAddr) noexcept { | 
| 231 |     reinterpret_cast<CallableT *>(CallableAddr)->~CallableT(); | 
| 232 |   } | 
| 233 |  | 
| 234 |   // The pointers to call/move/destroy functions are determined for each | 
| 235 |   // callable type (and called-as type, which determines the overload chosen). | 
| 236 |   // (definitions are out-of-line). | 
| 237 |  | 
| 238 |   // By default, we need an object that contains all the different | 
| 239 |   // type erased behaviors needed. Create a static instance of the struct type | 
| 240 |   // here and each instance will contain a pointer to it. | 
| 241 |   // Wrap in a struct to avoid https://gcc.gnu.org/PR71954 | 
| 242 |   template <typename CallableT, typename CalledAs, typename Enable = void> | 
| 243 |   struct CallbacksHolder { | 
| 244 |     static NonTrivialCallbacks Callbacks; | 
| 245 |   }; | 
| 246 |   // See if we can create a trivial callback. We need the callable to be | 
| 247 |   // trivially moved and trivially destroyed so that we don't have to store | 
| 248 |   // type erased callbacks for those operations. | 
| 249 |   template <typename CallableT, typename CalledAs> | 
| 250 |   struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> { | 
| 251 |     static TrivialCallback Callbacks; | 
| 252 |   }; | 
| 253 |  | 
| 254 |   // A simple tag type so the call-as type to be passed to the constructor. | 
| 255 |   template <typename T> struct CalledAs {}; | 
| 256 |  | 
| 257 |   // Essentially the "main" unique_function constructor, but subclasses | 
| 258 |   // provide the qualified type to be used for the call. | 
| 259 |   // (We always store a T, even if the call will use a pointer to const T). | 
| 260 |   template <typename CallableT, typename CalledAsT> | 
| 261 |   UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) { | 
| 262 |     bool IsInlineStorage = true; | 
| 263 |     void *CallableAddr = getInlineStorage(); | 
| 264 |     if (sizeof(CallableT) > InlineStorageSize || | 
| 265 |         alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) { | 
| 266 |       IsInlineStorage = false; | 
| 267 |       // Allocate out-of-line storage. FIXME: Use an explicit alignment | 
| 268 |       // parameter in C++17 mode. | 
| 269 |       auto Size = sizeof(CallableT); | 
| 270 |       auto Alignment = alignof(CallableT); | 
| 271 |       CallableAddr = allocate_buffer(Size, Alignment); | 
| 272 |       setOutOfLineStorage(Ptr: CallableAddr, Size, Alignment); | 
| 273 |     } | 
| 274 |  | 
| 275 |     // Now move into the storage. | 
| 276 |     new (CallableAddr) CallableT(std::move(Callable)); | 
| 277 |     CallbackAndInlineFlag.setPointerAndInt( | 
| 278 |         &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage); | 
| 279 |   } | 
| 280 |  | 
| 281 |   ~UniqueFunctionBase() { | 
| 282 |     if (!CallbackAndInlineFlag.getPointer()) | 
| 283 |       return; | 
| 284 |  | 
| 285 |     // Cache this value so we don't re-check it after type-erased operations. | 
| 286 |     bool IsInlineStorage = isInlineStorage(); | 
| 287 |  | 
| 288 |     if (!isTrivialCallback()) | 
| 289 |       getNonTrivialCallbacks()->DestroyPtr( | 
| 290 |           IsInlineStorage ? getInlineStorage() : getOutOfLineStorage()); | 
| 291 |  | 
| 292 |     if (!IsInlineStorage) | 
| 293 |       deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(), | 
| 294 |                         getOutOfLineStorageAlignment()); | 
| 295 |   } | 
| 296 |  | 
| 297 |   UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept { | 
| 298 |     // Copy the callback and inline flag. | 
| 299 |     CallbackAndInlineFlag = RHS.CallbackAndInlineFlag; | 
| 300 |  | 
| 301 |     // If the RHS is empty, just copying the above is sufficient. | 
| 302 |     if (!RHS) | 
| 303 |       return; | 
| 304 |  | 
| 305 |     if (!isInlineStorage()) { | 
| 306 |       // The out-of-line case is easiest to move. | 
| 307 |       StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage; | 
| 308 |     } else if (isTrivialCallback()) { | 
| 309 |       // Move is trivial, just memcpy the bytes across. | 
| 310 |       memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize); | 
| 311 |     } else { | 
| 312 |       // Non-trivial move, so dispatch to a type-erased implementation. | 
| 313 |       getNonTrivialCallbacks()->MovePtr(getInlineStorage(), | 
| 314 |                                         RHS.getInlineStorage()); | 
| 315 |     } | 
| 316 |  | 
| 317 |     // Clear the old callback and inline flag to get back to as-if-null. | 
| 318 |     RHS.CallbackAndInlineFlag = {}; | 
| 319 |  | 
| 320 | #ifndef NDEBUG | 
| 321 |     // In debug builds, we also scribble across the rest of the storage. | 
| 322 |     memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize); | 
| 323 | #endif | 
| 324 |   } | 
| 325 |  | 
| 326 |   UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept { | 
| 327 |     if (this == &RHS) | 
| 328 |       return *this; | 
| 329 |  | 
| 330 |     // Because we don't try to provide any exception safety guarantees we can | 
| 331 |     // implement move assignment very simply by first destroying the current | 
| 332 |     // object and then move-constructing over top of it. | 
| 333 |     this->~UniqueFunctionBase(); | 
| 334 |     new (this) UniqueFunctionBase(std::move(RHS)); | 
| 335 |     return *this; | 
| 336 |   } | 
| 337 |  | 
| 338 |   UniqueFunctionBase() = default; | 
| 339 |  | 
| 340 | public: | 
| 341 |   explicit operator bool() const { | 
| 342 |     return (bool)CallbackAndInlineFlag.getPointer(); | 
| 343 |   } | 
| 344 | }; | 
| 345 |  | 
| 346 | template <typename R, typename... P> | 
| 347 | template <typename CallableT, typename CalledAsT, typename Enable> | 
| 348 | typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase< | 
| 349 |     R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = { | 
| 350 |     &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>}; | 
| 351 |  | 
| 352 | template <typename R, typename... P> | 
| 353 | template <typename CallableT, typename CalledAsT> | 
| 354 | typename UniqueFunctionBase<R, P...>::TrivialCallback | 
| 355 |     UniqueFunctionBase<R, P...>::CallbacksHolder< | 
| 356 |         CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{ | 
| 357 |         &CallImpl<CalledAsT>}; | 
| 358 |  | 
| 359 | } // namespace detail | 
| 360 |  | 
| 361 | template <typename R, typename... P> | 
| 362 | class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> { | 
| 363 |   using Base = detail::UniqueFunctionBase<R, P...>; | 
| 364 |  | 
| 365 | public: | 
| 366 |   unique_function() = default; | 
| 367 |   unique_function(std::nullptr_t) {} | 
| 368 |   unique_function(unique_function &&) = default; | 
| 369 |   unique_function(const unique_function &) = delete; | 
| 370 |   unique_function &operator=(unique_function &&) = default; | 
| 371 |   unique_function &operator=(const unique_function &) = delete; | 
| 372 |  | 
| 373 |   template <typename CallableT> | 
| 374 |   unique_function( | 
| 375 |       CallableT Callable, | 
| 376 |       detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, | 
| 377 |       detail::EnableIfCallable<CallableT, R, P...> * = nullptr) | 
| 378 |       : Base(std::forward<CallableT>(Callable), | 
| 379 |              typename Base::template CalledAs<CallableT>{}) {} | 
| 380 |  | 
| 381 |   R operator()(P... Params) { | 
| 382 |     return this->getCallPtr()(this->getCalleePtr(), Params...); | 
| 383 |   } | 
| 384 | }; | 
| 385 |  | 
| 386 | template <typename R, typename... P> | 
| 387 | class unique_function<R(P...) const> | 
| 388 |     : public detail::UniqueFunctionBase<R, P...> { | 
| 389 |   using Base = detail::UniqueFunctionBase<R, P...>; | 
| 390 |  | 
| 391 | public: | 
| 392 |   unique_function() = default; | 
| 393 |   unique_function(std::nullptr_t) {} | 
| 394 |   unique_function(unique_function &&) = default; | 
| 395 |   unique_function(const unique_function &) = delete; | 
| 396 |   unique_function &operator=(unique_function &&) = default; | 
| 397 |   unique_function &operator=(const unique_function &) = delete; | 
| 398 |  | 
| 399 |   template <typename CallableT> | 
| 400 |   unique_function( | 
| 401 |       CallableT Callable, | 
| 402 |       detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr, | 
| 403 |       detail::EnableIfCallable<const CallableT, R, P...> * = nullptr) | 
| 404 |       : Base(std::forward<CallableT>(Callable), | 
| 405 |              typename Base::template CalledAs<const CallableT>{}) {} | 
| 406 |  | 
| 407 |   R operator()(P... Params) const { | 
| 408 |     return this->getCallPtr()(this->getCalleePtr(), Params...); | 
| 409 |   } | 
| 410 | }; | 
| 411 |  | 
| 412 | } // end namespace llvm | 
| 413 |  | 
| 414 | #endif // LLVM_ADT_FUNCTIONEXTRAS_H | 
| 415 |  |