1//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a definition of the ThreadLocalCache class. This class
10// provides support for defining thread local objects with non-static duration.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15#define MLIR_SUPPORT_THREADLOCALCACHE_H
16
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/Support/ManagedStatic.h"
20#include "llvm/Support/Mutex.h"
21
22namespace mlir {
23/// This class provides support for defining a thread local object with non
24/// static storage duration. This is very useful for situations in which a data
25/// cache has very large lock contention.
26template <typename ValueT>
27class ThreadLocalCache {
28 // Keep a separate shared_ptr protected state that can be acquired atomically
29 // instead of using shared_ptr's for each value. This avoids a problem
30 // where the instance shared_ptr is locked() successfully, and then the
31 // ThreadLocalCache gets destroyed before remove() can be called successfully.
32 struct PerInstanceState {
33 /// Remove the given value entry. This is generally called when a thread
34 /// local cache is destructing.
35 void remove(ValueT *value) {
36 // Erase the found value directly, because it is guaranteed to be in the
37 // list.
38 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
39 auto it =
40 llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
41 return instance.get() == value;
42 });
43 assert(it != instances.end() && "expected value to exist in cache");
44 instances.erase(it);
45 }
46
47 /// Owning pointers to all of the values that have been constructed for this
48 /// object in the static cache.
49 SmallVector<std::unique_ptr<ValueT>, 1> instances;
50
51 /// A mutex used when a new thread instance has been added to the cache for
52 /// this object.
53 llvm::sys::SmartMutex<true> instanceMutex;
54 };
55
56 /// The type used for the static thread_local cache. This is a map between an
57 /// instance of the non-static cache and a weak reference to an instance of
58 /// ValueT. We use a weak reference here so that the object can be destroyed
59 /// without needing to lock access to the cache itself.
60 struct CacheType
61 : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
62 ~CacheType() {
63 // Remove the values of this cache that haven't already expired.
64 for (auto &it : *this)
65 if (std::shared_ptr<ValueT> value = it.second.lock())
66 it.first->remove(value.get());
67 }
68
69 /// Clear out any unused entries within the map. This method is not
70 /// thread-safe, and should only be called by the same thread as the cache.
71 void clearExpiredEntries() {
72 for (auto it = this->begin(), e = this->end(); it != e;) {
73 auto curIt = it++;
74 if (curIt->second.expired())
75 this->erase(curIt);
76 }
77 }
78 };
79
80public:
81 ThreadLocalCache() = default;
82 ~ThreadLocalCache() {
83 // No cleanup is necessary here as the shared_pointer memory will go out of
84 // scope and invalidate the weak pointers held by the thread_local caches.
85 }
86
87 /// Return an instance of the value type for the current thread.
88 ValueT &get() {
89 // Check for an already existing instance for this thread.
90 CacheType &staticCache = getStaticCache();
91 std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
92 if (std::shared_ptr<ValueT> value = threadInstance.lock())
93 return *value;
94
95 // Otherwise, create a new instance for this thread.
96 llvm::sys::SmartScopedLock<true> threadInstanceLock(
97 perInstanceState->instanceMutex);
98 perInstanceState->instances.push_back(std::make_unique<ValueT>());
99 ValueT *instance = perInstanceState->instances.back().get();
100 threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
101
102 // Before returning the new instance, take the chance to clear out any used
103 // entries in the static map. The cache is only cleared within the same
104 // thread to remove the need to lock the cache itself.
105 staticCache.clearExpiredEntries();
106 return *instance;
107 }
108 ValueT &operator*() { return get(); }
109 ValueT *operator->() { return &get(); }
110
111private:
112 ThreadLocalCache(ThreadLocalCache &&) = delete;
113 ThreadLocalCache(const ThreadLocalCache &) = delete;
114 ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
115
116 /// Return the static thread local instance of the cache type.
117 static CacheType &getStaticCache() {
118 static LLVM_THREAD_LOCAL CacheType cache;
119 return cache;
120 }
121
122 std::shared_ptr<PerInstanceState> perInstanceState =
123 std::make_shared<PerInstanceState>();
124};
125} // namespace mlir
126
127#endif // MLIR_SUPPORT_THREADLOCALCACHE_H
128

source code of mlir/include/mlir/Support/ThreadLocalCache.h