// Copyright 2017 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Per-object thread-local storage
//!
//! This library provides the `ThreadLocal` type which allows a separate copy of
//! an object to be used for each thread. This allows for per-object
//! thread-local storage, unlike the standard library's `thread_local!` macro
//! which only allows static thread-local storage.
//!
//! Per-thread objects are not destroyed when a thread exits. Instead, objects
//! are only destroyed when the `ThreadLocal` containing them is destroyed.
//!
//! You can also iterate over the thread-local values of all threads in a
//! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can
//! only be done if you have mutable access to the `ThreadLocal` object, which
//! guarantees that you are the only thread currently accessing it.
//!
//! Note that since thread IDs are recycled when a thread exits, it is possible
//! for one thread to retrieve the object of another thread. Since this can only
//! occur after a thread has exited, it does not lead to any race conditions.
//!
//! # Examples
//!
//! Basic usage of `ThreadLocal`:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! let tls: ThreadLocal<u32> = ThreadLocal::new();
//! assert_eq!(tls.get(), None);
//! assert_eq!(tls.get_or(|| 5), &5);
//! assert_eq!(tls.get(), Some(&5));
//! ```
//!
//! Combining thread-local values into a single result:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! use std::sync::Arc;
//! use std::cell::Cell;
//! use std::thread;
//!
//! let tls = Arc::new(ThreadLocal::new());
//!
//! // Create a bunch of threads to do stuff
//! for _ in 0..5 {
//!     let tls2 = tls.clone();
//!     thread::spawn(move || {
//!         // Increment a counter to count some event...
//!         let cell = tls2.get_or(|| Cell::new(0));
//!         cell.set(cell.get() + 1);
//!     }).join().unwrap();
//! }
//!
//! // Once all threads are done, collect the counter values and return the
//! // sum of all thread-local counter values.
//! let tls = Arc::try_unwrap(tls).unwrap();
//! let total = tls.into_iter().fold(0, |x, y| x + y.get());
//! assert_eq!(total, 5);
//! ```

#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]
#![cfg_attr(feature = "nightly", feature(thread_local))]

mod cached;
mod thread_id;
mod unreachable;

#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use thread_id::Thread;
use unreachable::UncheckedResultExt;

// Use usize::BITS once it has stabilized and the MSRV has been bumped.
#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;

/// The total number of buckets stored in each thread local.
const BUCKETS: usize = (POINTER_WIDTH + 1) as usize;
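
// Bucket n holds 2^(n-1) entries for n >= 1, while bucket 0 holds a single
// entry (the size only starts doubling after the first bucket). The total
// capacity across all POINTER_WIDTH + 1 buckets is therefore
// 1 + 1 + 2 + 4 + ... + 2^(POINTER_WIDTH - 1) = 2^POINTER_WIDTH entries,
// one per possible thread ID.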

/// Thread-local variable wrapper
///
/// See the [module-level documentation](index.html) for more.
pub struct ThreadLocal<T: Send> {
    /// The buckets in the thread local. The nth bucket contains `2^(n-1)`
    /// elements. Each bucket is lazily allocated.
    buckets: [AtomicPtr<Entry<T>>; BUCKETS],

    /// The number of values in the thread local. This can be less than the real number of values,
    /// but is never more.
    values: AtomicUsize,
}

struct Entry<T> {
    present: AtomicBool,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        unsafe {
            if *self.present.get_mut() {
                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
            }
        }
    }
}

// ThreadLocal is always Sync, even if T isn't
unsafe impl<T: Send> Sync for ThreadLocal<T> {}

impl<T: Send> Default for ThreadLocal<T> {
    fn default() -> ThreadLocal<T> {
        ThreadLocal::new()
    }
}

impl<T: Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        let mut bucket_size = 1;

        // Free each non-null bucket
        for (i, bucket) in self.buckets.iter_mut().enumerate() {
            let bucket_ptr = *bucket.get_mut();

            let this_bucket_size = bucket_size;
            if i != 0 {
                bucket_size <<= 1;
            }

            if bucket_ptr.is_null() {
                continue;
            }

            unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
        }
    }
}

impl<T: Send> ThreadLocal<T> {
    /// Creates a new empty `ThreadLocal`.
    pub fn new() -> ThreadLocal<T> {
        Self::with_capacity(2)
    }

    /// Creates a new `ThreadLocal` with an initial capacity. If fewer than
    /// `capacity` threads access the thread local, it will never reallocate.
    /// The capacity may be rounded up to the nearest power of two.
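    ///
    /// A minimal sketch of pre-sizing for a known thread count:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// // Sized up front; no further bucket allocation is needed for the
    /// // first 8 threads that touch it.
    /// let tls: ThreadLocal<u32> = ThreadLocal::with_capacity(8);
    /// assert_eq!(tls.get(), None);
    /// ```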
    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
        // Number of buckets needed to hold `capacity` entries:
        // `POINTER_WIDTH - leading_zeros` is the number of significant bits
        // in `capacity - 1`, and the extra `+ 1` accounts for bucket 0,
        // which holds one entry without doubling the running size.
        let allocated_buckets = capacity
            .checked_sub(1)
            .map(|c| usize::from(POINTER_WIDTH) - (c.leading_zeros() as usize) + 1)
            .unwrap_or(0);

        let mut buckets = [ptr::null_mut(); BUCKETS];
        let mut bucket_size = 1;
        for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
            *bucket = allocate_bucket::<T>(bucket_size);

            if i != 0 {
                bucket_size <<= 1;
            }
        }

        ThreadLocal {
            // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
            // representation as a sequence of their inner type.
            buckets: unsafe { mem::transmute(buckets) },
            values: AtomicUsize::new(0),
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        self.get_inner(thread_id::get())
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist.
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> T,
    {
        unsafe {
            self.get_or_try(|| Ok::<T, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist. If `create` fails, that error is returned and no element is
    /// added.
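    ///
    /// A minimal sketch using a fallible initializer:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// // A failed `create` leaves the ThreadLocal unchanged...
    /// assert!(tls.get_or_try(|| Err::<u32, ()>(())).is_err());
    /// assert_eq!(tls.get(), None);
    /// // ...and a later successful call inserts as usual.
    /// assert_eq!(tls.get_or_try(|| Ok::<u32, ()>(5)), Ok(&5));
    /// ```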
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let thread = thread_id::get();
        if let Some(val) = self.get_inner(thread) {
            return Ok(val);
        }

        Ok(self.insert(create()?))
    }

    fn get_inner(&self, thread: Thread) -> Option<&T> {
        let bucket_ptr =
            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            return None;
        }
        unsafe {
            let entry = &*bucket_ptr.add(thread.index);
            // Read without atomic operations as only this thread can set the value.
            if (&entry.present as *const _ as *const bool).read() {
                Some(&*(&*entry.value.get()).as_ptr())
            } else {
                None
            }
        }
    }

    #[cold]
    fn insert(&self, data: T) -> &T {
        let thread = thread_id::get();
        let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
        let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

        // If the bucket doesn't already exist, we need to allocate it
        let bucket_ptr = if bucket_ptr.is_null() {
            let new_bucket = allocate_bucket(thread.bucket_size);

            match bucket_atomic_ptr.compare_exchange(
                ptr::null_mut(),
                new_bucket,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => new_bucket,
                // If the bucket value changed (from null), that means
                // another thread stored a new bucket before we could,
                // and we can free our bucket and use that one instead
                Err(bucket_ptr) => {
                    unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
                    bucket_ptr
                }
            }
        } else {
            bucket_ptr
        };

        // Insert the new element into the bucket
        let entry = unsafe { &*bucket_ptr.add(thread.index) };
        let value_ptr = entry.value.get();
        unsafe { value_ptr.write(MaybeUninit::new(data)) };
        entry.present.store(true, Ordering::Release);

        self.values.fetch_add(1, Ordering::Release);

        unsafe { &*(&*value_ptr).as_ptr() }
    }

    /// Returns an iterator over the local values of all threads in unspecified
    /// order.
    ///
    /// This call can be done safely, as `T` is required to implement [`Sync`].
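    ///
    /// A minimal sketch (single-threaded, so only one value is present):
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 7);
    /// // Iteration borrows the ThreadLocal immutably and may run while
    /// // other threads are still inserting values.
    /// assert_eq!(tls.iter().sum::<u32>(), 7);
    /// ```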
    pub fn iter(&self) -> Iter<'_, T>
    where
        T: Sync,
    {
        Iter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Returns a mutable iterator over the local values of all threads in
    /// unspecified order.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely---the mutable borrow statically guarantees no other
    /// threads are currently accessing their associated values.
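    ///
    /// A minimal sketch of mutating every value in place:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 1);
    /// for value in tls.iter_mut() {
    ///     *value += 1;
    /// }
    /// assert_eq!(tls.get(), Some(&2));
    /// ```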
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Removes all thread-specific values from the `ThreadLocal`, effectively
    /// resetting it to its original state.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely---the mutable borrow statically guarantees no other
    /// threads are currently accessing their associated values.
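    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 5);
    /// tls.clear();
    /// // All values (including this thread's) are dropped.
    /// assert_eq!(tls.get(), None);
    /// ```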
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}

impl<T: Send> IntoIterator for ThreadLocal<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        IntoIter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }
}

impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;

    fn into_iter(self) -> IterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> ThreadLocal<T> {
    /// Returns the element for the current thread, or creates a default one if
    /// it doesn't exist.
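    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    ///
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// // Equivalent to `tls.get_or(Default::default)`.
    /// assert_eq!(tls.get_or_default(), &0);
    /// ```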
    pub fn get_or_default(&self) -> &T {
        self.get_or(Default::default)
    }
}

impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}

impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

#[derive(Debug)]
struct RawIter {
    yielded: usize,
    bucket: usize,
    bucket_size: usize,
    index: usize,
}
impl RawIter {
    #[inline]
    fn new() -> Self {
        Self {
            yielded: 0,
            bucket: 0,
            bucket_size: 1,
            index: 0,
        }
    }

    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
        while self.bucket < BUCKETS {
            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
            let bucket = bucket.load(Ordering::Acquire);

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    if entry.present.load(Ordering::Acquire) {
                        self.yielded += 1;
                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
                    }
                }
            }

            self.next_bucket();
        }
        None
    }
    fn next_mut<'a, T: Send>(
        &mut self,
        thread_local: &'a mut ThreadLocal<T>,
    ) -> Option<&'a mut Entry<T>> {
        if *thread_local.values.get_mut() == self.yielded {
            return None;
        }

        loop {
            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
            let bucket = *bucket.get_mut();

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &mut *bucket.add(self.index) };
                    self.index += 1;
                    if *entry.present.get_mut() {
                        self.yielded += 1;
                        return Some(entry);
                    }
                }
            }

            self.next_bucket();
        }
    }

    #[inline]
    fn next_bucket(&mut self) {
        if self.bucket != 0 {
            self.bucket_size <<= 1;
        }
        self.bucket += 1;
        self.index = 0;
    }

    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = thread_local.values.load(Ordering::Acquire);
        // `values` can lag behind the number of initialized entries, so the
        // iterator may already have yielded more items than the loaded total;
        // saturate to avoid underflow.
        (total.saturating_sub(self.yielded), None)
    }
    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
        let remaining = total - self.yielded;
        (remaining, Some(remaining))
    }
}

/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
    thread_local: &'a ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<Self::Item> {
        self.raw.next(self.thread_local)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint(self.thread_local)
    }
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

/// Mutable iterator over the contents of a `ThreadLocal`.
pub struct IterMut<'a, T: Send> {
    thread_local: &'a mut ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<&'a mut T> {
        self.raw
            .next_mut(self.thread_local)
            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}

// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
// this thread's value that potentially aliases with a mutable reference we have given out.
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IterMut").field("raw", &self.raw).finish()
    }
}

/// An iterator that moves out of a `ThreadLocal`.
#[derive(Debug)]
pub struct IntoIter<T: Send> {
    thread_local: ThreadLocal<T>,
    raw: RawIter,
}

impl<T: Send> Iterator for IntoIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<T> {
        self.raw.next_mut(&mut self.thread_local).map(|entry| {
            *entry.present.get_mut() = false;
            unsafe {
                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
            }
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(&self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}

fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
    Box::into_raw(
        (0..size)
            .map(|_| Entry::<T> {
                present: AtomicBool::new(false),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect(),
    ) as *mut _
}

unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
    let _ = Box::from_raw(std::slice::from_raw_parts_mut(bucket, size));
}

#[cfg(test)]
mod tests {
    use super::ThreadLocal;
    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        })
        .join()
        .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || {
                tls3.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(tls2);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();

        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn test_drop() {
        let local = ThreadLocal::new();
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let dropped = Arc::new(AtomicUsize::new(0));
        local.get_or(|| Dropped(dropped.clone()));
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
    }
}