| 1 | #![allow (clippy::unit_arg)] |
| 2 | |
| 3 | use std::cmp; |
| 4 | use std::fmt; |
| 5 | use std::marker::PhantomData; |
| 6 | use std::mem; |
| 7 | use std::num::NonZeroUsize; |
| 8 | |
| 9 | use crate::errors::InvalidThreadAccess; |
| 10 | use crate::registry; |
| 11 | use crate::thread_id; |
| 12 | use crate::StackToken; |
| 13 | |
| 14 | /// A [`Sticky<T>`] keeps a value T stored in a thread. |
| 15 | /// |
| 16 | /// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a |
| 17 | /// similar interface. The difference is that whereas [`Fragile`](crate::Fragile) has |
| 18 | /// its destructor called in the thread where the value was sent, a |
| 19 | /// [`Sticky`] that is moved to another thread will have the internal |
| 20 | /// destructor called when the originating thread tears down. |
| 21 | /// |
| 22 | /// Because [`Sticky`] allows values to be kept alive for longer than the |
| 23 | /// [`Sticky`] itself, it requires all its contents to be `'static` for |
| 24 | /// soundness. More importantly it also requires the use of [`StackToken`]s. |
| 25 | /// For information about how to use stack tokens and why they are neded, |
| 26 | /// refer to [`stack_token!`](crate::stack_token). |
| 27 | /// |
| 28 | /// As this uses TLS internally the general rules about the platform limitations |
| 29 | /// of destructors for TLS apply. |
| 30 | pub struct Sticky<T: 'static> { |
| 31 | item_id: registry::ItemId, |
| 32 | thread_id: NonZeroUsize, |
| 33 | _marker: PhantomData<*mut T>, |
| 34 | } |
| 35 | |
| 36 | impl<T> Drop for Sticky<T> { |
| 37 | fn drop(&mut self) { |
| 38 | // if the type needs dropping we can only do so on the |
| 39 | // right thread. worst case we leak the value until the |
| 40 | // thread dies. |
| 41 | if mem::needs_drop::<T>() { |
| 42 | unsafe { |
| 43 | if self.is_valid() { |
| 44 | self.unsafe_take_value(); |
| 45 | } |
| 46 | } |
| 47 | |
| 48 | // otherwise we take the liberty to drop the value |
| 49 | // right here and now. We can however only do that if |
| 50 | // we are on the right thread. If we are not, we again |
| 51 | // need to wait for the thread to shut down. |
| 52 | } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) { |
| 53 | unsafe { |
| 54 | (entry.drop)(entry.ptr); |
| 55 | } |
| 56 | } |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | impl<T> Sticky<T> { |
| 61 | /// Creates a new [`Sticky`] wrapping a `value`. |
| 62 | /// |
| 63 | /// The value that is moved into the [`Sticky`] can be non `Send` and |
| 64 | /// will be anchored to the thread that created the object. If the |
| 65 | /// sticky wrapper type ends up being send from thread to thread |
| 66 | /// only the original thread can interact with the value. |
| 67 | pub fn new(value: T) -> Self { |
| 68 | let entry = registry::Entry { |
| 69 | ptr: Box::into_raw(Box::new(value)).cast(), |
| 70 | drop: |ptr| { |
| 71 | let ptr = ptr.cast::<T>(); |
| 72 | // SAFETY: This callback will only be called once, with the |
| 73 | // above pointer. |
| 74 | drop(unsafe { Box::from_raw(ptr) }); |
| 75 | }, |
| 76 | }; |
| 77 | |
| 78 | let thread_id = thread_id::get(); |
| 79 | let item_id = registry::insert(thread_id, entry); |
| 80 | |
| 81 | Sticky { |
| 82 | item_id, |
| 83 | thread_id, |
| 84 | _marker: PhantomData, |
| 85 | } |
| 86 | } |
| 87 | |
| 88 | #[inline (always)] |
| 89 | fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R { |
| 90 | self.assert_thread(); |
| 91 | |
| 92 | registry::with(self.item_id, self.thread_id, |entry| { |
| 93 | f(entry.ptr.cast::<T>()) |
| 94 | }) |
| 95 | } |
| 96 | |
| 97 | /// Returns `true` if the access is valid. |
| 98 | /// |
| 99 | /// This will be `false` if the value was sent to another thread. |
| 100 | #[inline (always)] |
| 101 | pub fn is_valid(&self) -> bool { |
| 102 | thread_id::get() == self.thread_id |
| 103 | } |
| 104 | |
| 105 | #[inline (always)] |
| 106 | fn assert_thread(&self) { |
| 107 | if !self.is_valid() { |
| 108 | panic!("trying to access wrapped value in sticky container from incorrect thread." ); |
| 109 | } |
| 110 | } |
| 111 | |
| 112 | /// Consumes the `Sticky`, returning the wrapped value. |
| 113 | /// |
| 114 | /// # Panics |
| 115 | /// |
| 116 | /// Panics if called from a different thread than the one where the |
| 117 | /// original value was created. |
| 118 | pub fn into_inner(mut self) -> T { |
| 119 | self.assert_thread(); |
| 120 | unsafe { |
| 121 | let rv = self.unsafe_take_value(); |
| 122 | mem::forget(self); |
| 123 | rv |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | unsafe fn unsafe_take_value(&mut self) -> T { |
| 128 | let ptr = registry::remove(self.item_id, self.thread_id) |
| 129 | .ptr |
| 130 | .cast::<T>(); |
| 131 | *Box::from_raw(ptr) |
| 132 | } |
| 133 | |
| 134 | /// Consumes the `Sticky`, returning the wrapped value if successful. |
| 135 | /// |
| 136 | /// The wrapped value is returned if this is called from the same thread |
| 137 | /// as the one where the original value was created, otherwise the |
| 138 | /// `Sticky` is returned as `Err(self)`. |
| 139 | pub fn try_into_inner(self) -> Result<T, Self> { |
| 140 | if self.is_valid() { |
| 141 | Ok(self.into_inner()) |
| 142 | } else { |
| 143 | Err(self) |
| 144 | } |
| 145 | } |
| 146 | |
| 147 | /// Immutably borrows the wrapped value. |
| 148 | /// |
| 149 | /// # Panics |
| 150 | /// |
| 151 | /// Panics if the calling thread is not the one that wrapped the value. |
| 152 | /// For a non-panicking variant, use [`try_get`](#method.try_get`). |
| 153 | pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T { |
| 154 | self.with_value(|value| unsafe { &*value }) |
| 155 | } |
| 156 | |
| 157 | /// Mutably borrows the wrapped value. |
| 158 | /// |
| 159 | /// # Panics |
| 160 | /// |
| 161 | /// Panics if the calling thread is not the one that wrapped the value. |
| 162 | /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`). |
| 163 | pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T { |
| 164 | self.with_value(|value| unsafe { &mut *value }) |
| 165 | } |
| 166 | |
| 167 | /// Tries to immutably borrow the wrapped value. |
| 168 | /// |
| 169 | /// Returns `None` if the calling thread is not the one that wrapped the value. |
| 170 | pub fn try_get<'stack>( |
| 171 | &'stack self, |
| 172 | _proof: &'stack StackToken, |
| 173 | ) -> Result<&'stack T, InvalidThreadAccess> { |
| 174 | if self.is_valid() { |
| 175 | Ok(self.with_value(|value| unsafe { &*value })) |
| 176 | } else { |
| 177 | Err(InvalidThreadAccess) |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | /// Tries to mutably borrow the wrapped value. |
| 182 | /// |
| 183 | /// Returns `None` if the calling thread is not the one that wrapped the value. |
| 184 | pub fn try_get_mut<'stack>( |
| 185 | &'stack mut self, |
| 186 | _proof: &'stack StackToken, |
| 187 | ) -> Result<&'stack mut T, InvalidThreadAccess> { |
| 188 | if self.is_valid() { |
| 189 | Ok(self.with_value(|value| unsafe { &mut *value })) |
| 190 | } else { |
| 191 | Err(InvalidThreadAccess) |
| 192 | } |
| 193 | } |
| 194 | } |
| 195 | |
| 196 | impl<T> From<T> for Sticky<T> { |
| 197 | #[inline ] |
| 198 | fn from(t: T) -> Sticky<T> { |
| 199 | Sticky::new(t) |
| 200 | } |
| 201 | } |
| 202 | |
| 203 | impl<T: Clone> Clone for Sticky<T> { |
| 204 | #[inline ] |
| 205 | fn clone(&self) -> Sticky<T> { |
| 206 | crate::stack_token!(tok); |
| 207 | Sticky::new(self.get(tok).clone()) |
| 208 | } |
| 209 | } |
| 210 | |
| 211 | impl<T: Default> Default for Sticky<T> { |
| 212 | #[inline ] |
| 213 | fn default() -> Sticky<T> { |
| 214 | Sticky::new(T::default()) |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | impl<T: PartialEq> PartialEq for Sticky<T> { |
| 219 | #[inline ] |
| 220 | fn eq(&self, other: &Sticky<T>) -> bool { |
| 221 | crate::stack_token!(tok); |
| 222 | *self.get(tok) == *other.get(tok) |
| 223 | } |
| 224 | } |
| 225 | |
| 226 | impl<T: Eq> Eq for Sticky<T> {} |
| 227 | |
| 228 | impl<T: PartialOrd> PartialOrd for Sticky<T> { |
| 229 | #[inline ] |
| 230 | fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> { |
| 231 | crate::stack_token!(tok); |
| 232 | self.get(tok).partial_cmp(other.get(tok)) |
| 233 | } |
| 234 | |
| 235 | #[inline ] |
| 236 | fn lt(&self, other: &Sticky<T>) -> bool { |
| 237 | crate::stack_token!(tok); |
| 238 | *self.get(tok) < *other.get(tok) |
| 239 | } |
| 240 | |
| 241 | #[inline ] |
| 242 | fn le(&self, other: &Sticky<T>) -> bool { |
| 243 | crate::stack_token!(tok); |
| 244 | *self.get(tok) <= *other.get(tok) |
| 245 | } |
| 246 | |
| 247 | #[inline ] |
| 248 | fn gt(&self, other: &Sticky<T>) -> bool { |
| 249 | crate::stack_token!(tok); |
| 250 | *self.get(tok) > *other.get(tok) |
| 251 | } |
| 252 | |
| 253 | #[inline ] |
| 254 | fn ge(&self, other: &Sticky<T>) -> bool { |
| 255 | crate::stack_token!(tok); |
| 256 | *self.get(tok) >= *other.get(tok) |
| 257 | } |
| 258 | } |
| 259 | |
| 260 | impl<T: Ord> Ord for Sticky<T> { |
| 261 | #[inline ] |
| 262 | fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering { |
| 263 | crate::stack_token!(tok); |
| 264 | self.get(tok).cmp(other.get(tok)) |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | impl<T: fmt::Display> fmt::Display for Sticky<T> { |
| 269 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { |
| 270 | crate::stack_token!(tok); |
| 271 | fmt::Display::fmt(self.get(tok), f) |
| 272 | } |
| 273 | } |
| 274 | |
| 275 | impl<T: fmt::Debug> fmt::Debug for Sticky<T> { |
| 276 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { |
| 277 | crate::stack_token!(tok); |
| 278 | match self.try_get(tok) { |
| 279 | Ok(value) => f.debug_struct("Sticky" ).field("value" , value).finish(), |
| 280 | Err(..) => { |
| 281 | struct InvalidPlaceholder; |
| 282 | impl fmt::Debug for InvalidPlaceholder { |
| 283 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| 284 | f.write_str("<invalid thread>" ) |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | f.debug_struct("Sticky" ) |
| 289 | .field("value" , &InvalidPlaceholder) |
| 290 | .finish() |
| 291 | } |
| 292 | } |
| 293 | } |
| 294 | } |
| 295 | |
| 296 | // similar as for fragile ths type is sync because it only accesses TLS data |
| 297 | // which is thread local. There is nothing that needs to be synchronized. |
| 298 | unsafe impl<T> Sync for Sticky<T> {} |
| 299 | |
| 300 | // The entire point of this type is to be Send |
| 301 | unsafe impl<T> Send for Sticky<T> {} |
| 302 | |
#[test]
fn test_basic() {
    use std::thread;

    crate::stack_token!(tok);
    let sticky = Sticky::new(true);
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(tok), &true);
    assert!(sticky.try_get(tok).is_ok());

    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        // The value is anchored to the spawning thread, so access fails here.
        assert!(sticky.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
| 318 | |
#[test]
fn test_mut() {
    crate::stack_token!(tok);
    let mut sticky = Sticky::new(true);

    let slot = sticky.get_mut(tok);
    *slot = false;

    assert_eq!(sticky.to_string(), "false");
    assert_eq!(sticky.get(tok), &false);
}
| 327 | |
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let sticky = Sticky::new(true);
    let worker = thread::spawn(move || {
        crate::stack_token!(tok);
        // `get` panics on a foreign thread...
        let _ = sticky.get(tok);
    });
    // ...which surfaces as `Err` from `join`, and `unwrap` re-panics here.
    worker.join().unwrap();
}
| 340 | |
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Flips the shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let sticky = Sticky::new(SetOnDrop(Arc::clone(&dropped)));
    mem::drop(sticky);
    assert!(dropped.load(Ordering::SeqCst));
}
| 356 | |
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let dropped = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&dropped);

    let owner = thread::spawn(move || {
        // Flips the shared flag when dropped.
        struct SetOnDrop(Arc<AtomicBool>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        let sticky = Sticky::new(SetOnDrop(Arc::clone(&flag)));

        // Move the sticky into a foreign thread and let it go out of scope
        // there; the wrapped value must not be dropped on that thread.
        let foreign = thread::spawn(move || {
            crate::stack_token!(tok);
            let _ = sticky.try_get(tok);
        });
        assert!(foreign.join().is_ok());

        assert!(!flag.load(Ordering::SeqCst));
    });
    owner.join().unwrap();

    // Once the owning thread has finished, the value has been dropped.
    assert!(dropped.load(Ordering::SeqCst));
}
| 392 | |
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is `!Send`, yet the `Sticky` wrapper itself may be moved.
    let sticky = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        let access = sticky.try_get(tok);
        assert!(access.is_err());
    });
    handle.join().unwrap();
}
| 405 | |
#[test]
fn test_two_stickies() {
    // Has a (no-op) destructor, so `mem::needs_drop::<Wat>()` is true and
    // the "take and drop on the owning thread" path in `Drop` is exercised.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {}
    }

    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // Dropping both in creation order on the owning thread must not panic.
    drop(first);
    drop(second);
}
| 424 | |