1 | #![allow (clippy::unit_arg)] |
2 | |
3 | use std::cmp; |
4 | use std::fmt; |
5 | use std::marker::PhantomData; |
6 | use std::mem; |
7 | use std::num::NonZeroUsize; |
8 | |
9 | use crate::errors::InvalidThreadAccess; |
10 | use crate::registry; |
11 | use crate::thread_id; |
12 | use crate::StackToken; |
13 | |
14 | /// A [`Sticky<T>`] keeps a value T stored in a thread. |
15 | /// |
16 | /// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a |
17 | /// similar interface. The difference is that whereas [`Fragile`](crate::Fragile) has |
18 | /// its destructor called in the thread where the value was sent, a |
19 | /// [`Sticky`] that is moved to another thread will have the internal |
20 | /// destructor called when the originating thread tears down. |
21 | /// |
22 | /// Because [`Sticky`] allows values to be kept alive for longer than the |
23 | /// [`Sticky`] itself, it requires all its contents to be `'static` for |
24 | /// soundness. More importantly it also requires the use of [`StackToken`]s. |
25 | /// For information about how to use stack tokens and why they are neded, |
26 | /// refer to [`stack_token!`](crate::stack_token). |
27 | /// |
28 | /// As this uses TLS internally the general rules about the platform limitations |
29 | /// of destructors for TLS apply. |
30 | pub struct Sticky<T: 'static> { |
31 | item_id: registry::ItemId, |
32 | thread_id: NonZeroUsize, |
33 | _marker: PhantomData<*mut T>, |
34 | } |
35 | |
36 | impl<T> Drop for Sticky<T> { |
37 | fn drop(&mut self) { |
38 | // if the type needs dropping we can only do so on the |
39 | // right thread. worst case we leak the value until the |
40 | // thread dies. |
41 | if mem::needs_drop::<T>() { |
42 | unsafe { |
43 | if self.is_valid() { |
44 | self.unsafe_take_value(); |
45 | } |
46 | } |
47 | |
48 | // otherwise we take the liberty to drop the value |
49 | // right here and now. We can however only do that if |
50 | // we are on the right thread. If we are not, we again |
51 | // need to wait for the thread to shut down. |
52 | } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) { |
53 | unsafe { |
54 | (entry.drop)(entry.ptr); |
55 | } |
56 | } |
57 | } |
58 | } |
59 | |
60 | impl<T> Sticky<T> { |
61 | /// Creates a new [`Sticky`] wrapping a `value`. |
62 | /// |
63 | /// The value that is moved into the [`Sticky`] can be non `Send` and |
64 | /// will be anchored to the thread that created the object. If the |
65 | /// sticky wrapper type ends up being send from thread to thread |
66 | /// only the original thread can interact with the value. |
67 | pub fn new(value: T) -> Self { |
68 | let entry = registry::Entry { |
69 | ptr: Box::into_raw(Box::new(value)).cast(), |
70 | drop: |ptr| { |
71 | let ptr = ptr.cast::<T>(); |
72 | // SAFETY: This callback will only be called once, with the |
73 | // above pointer. |
74 | drop(unsafe { Box::from_raw(ptr) }); |
75 | }, |
76 | }; |
77 | |
78 | let thread_id = thread_id::get(); |
79 | let item_id = registry::insert(thread_id, entry); |
80 | |
81 | Sticky { |
82 | item_id, |
83 | thread_id, |
84 | _marker: PhantomData, |
85 | } |
86 | } |
87 | |
88 | #[inline (always)] |
89 | fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R { |
90 | self.assert_thread(); |
91 | |
92 | registry::with(self.item_id, self.thread_id, |entry| { |
93 | f(entry.ptr.cast::<T>()) |
94 | }) |
95 | } |
96 | |
97 | /// Returns `true` if the access is valid. |
98 | /// |
99 | /// This will be `false` if the value was sent to another thread. |
100 | #[inline (always)] |
101 | pub fn is_valid(&self) -> bool { |
102 | thread_id::get() == self.thread_id |
103 | } |
104 | |
105 | #[inline (always)] |
106 | fn assert_thread(&self) { |
107 | if !self.is_valid() { |
108 | panic!("trying to access wrapped value in sticky container from incorrect thread." ); |
109 | } |
110 | } |
111 | |
112 | /// Consumes the `Sticky`, returning the wrapped value. |
113 | /// |
114 | /// # Panics |
115 | /// |
116 | /// Panics if called from a different thread than the one where the |
117 | /// original value was created. |
118 | pub fn into_inner(mut self) -> T { |
119 | self.assert_thread(); |
120 | unsafe { |
121 | let rv = self.unsafe_take_value(); |
122 | mem::forget(self); |
123 | rv |
124 | } |
125 | } |
126 | |
127 | unsafe fn unsafe_take_value(&mut self) -> T { |
128 | let ptr = registry::remove(self.item_id, self.thread_id) |
129 | .ptr |
130 | .cast::<T>(); |
131 | *Box::from_raw(ptr) |
132 | } |
133 | |
134 | /// Consumes the `Sticky`, returning the wrapped value if successful. |
135 | /// |
136 | /// The wrapped value is returned if this is called from the same thread |
137 | /// as the one where the original value was created, otherwise the |
138 | /// `Sticky` is returned as `Err(self)`. |
139 | pub fn try_into_inner(self) -> Result<T, Self> { |
140 | if self.is_valid() { |
141 | Ok(self.into_inner()) |
142 | } else { |
143 | Err(self) |
144 | } |
145 | } |
146 | |
147 | /// Immutably borrows the wrapped value. |
148 | /// |
149 | /// # Panics |
150 | /// |
151 | /// Panics if the calling thread is not the one that wrapped the value. |
152 | /// For a non-panicking variant, use [`try_get`](#method.try_get`). |
153 | pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T { |
154 | self.with_value(|value| unsafe { &*value }) |
155 | } |
156 | |
157 | /// Mutably borrows the wrapped value. |
158 | /// |
159 | /// # Panics |
160 | /// |
161 | /// Panics if the calling thread is not the one that wrapped the value. |
162 | /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`). |
163 | pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T { |
164 | self.with_value(|value| unsafe { &mut *value }) |
165 | } |
166 | |
167 | /// Tries to immutably borrow the wrapped value. |
168 | /// |
169 | /// Returns `None` if the calling thread is not the one that wrapped the value. |
170 | pub fn try_get<'stack>( |
171 | &'stack self, |
172 | _proof: &'stack StackToken, |
173 | ) -> Result<&'stack T, InvalidThreadAccess> { |
174 | if self.is_valid() { |
175 | Ok(self.with_value(|value| unsafe { &*value })) |
176 | } else { |
177 | Err(InvalidThreadAccess) |
178 | } |
179 | } |
180 | |
181 | /// Tries to mutably borrow the wrapped value. |
182 | /// |
183 | /// Returns `None` if the calling thread is not the one that wrapped the value. |
184 | pub fn try_get_mut<'stack>( |
185 | &'stack mut self, |
186 | _proof: &'stack StackToken, |
187 | ) -> Result<&'stack mut T, InvalidThreadAccess> { |
188 | if self.is_valid() { |
189 | Ok(self.with_value(|value| unsafe { &mut *value })) |
190 | } else { |
191 | Err(InvalidThreadAccess) |
192 | } |
193 | } |
194 | } |
195 | |
196 | impl<T> From<T> for Sticky<T> { |
197 | #[inline ] |
198 | fn from(t: T) -> Sticky<T> { |
199 | Sticky::new(t) |
200 | } |
201 | } |
202 | |
203 | impl<T: Clone> Clone for Sticky<T> { |
204 | #[inline ] |
205 | fn clone(&self) -> Sticky<T> { |
206 | crate::stack_token!(tok); |
207 | Sticky::new(self.get(tok).clone()) |
208 | } |
209 | } |
210 | |
211 | impl<T: Default> Default for Sticky<T> { |
212 | #[inline ] |
213 | fn default() -> Sticky<T> { |
214 | Sticky::new(T::default()) |
215 | } |
216 | } |
217 | |
218 | impl<T: PartialEq> PartialEq for Sticky<T> { |
219 | #[inline ] |
220 | fn eq(&self, other: &Sticky<T>) -> bool { |
221 | crate::stack_token!(tok); |
222 | *self.get(tok) == *other.get(tok) |
223 | } |
224 | } |
225 | |
226 | impl<T: Eq> Eq for Sticky<T> {} |
227 | |
228 | impl<T: PartialOrd> PartialOrd for Sticky<T> { |
229 | #[inline ] |
230 | fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> { |
231 | crate::stack_token!(tok); |
232 | self.get(tok).partial_cmp(other.get(tok)) |
233 | } |
234 | |
235 | #[inline ] |
236 | fn lt(&self, other: &Sticky<T>) -> bool { |
237 | crate::stack_token!(tok); |
238 | *self.get(tok) < *other.get(tok) |
239 | } |
240 | |
241 | #[inline ] |
242 | fn le(&self, other: &Sticky<T>) -> bool { |
243 | crate::stack_token!(tok); |
244 | *self.get(tok) <= *other.get(tok) |
245 | } |
246 | |
247 | #[inline ] |
248 | fn gt(&self, other: &Sticky<T>) -> bool { |
249 | crate::stack_token!(tok); |
250 | *self.get(tok) > *other.get(tok) |
251 | } |
252 | |
253 | #[inline ] |
254 | fn ge(&self, other: &Sticky<T>) -> bool { |
255 | crate::stack_token!(tok); |
256 | *self.get(tok) >= *other.get(tok) |
257 | } |
258 | } |
259 | |
260 | impl<T: Ord> Ord for Sticky<T> { |
261 | #[inline ] |
262 | fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering { |
263 | crate::stack_token!(tok); |
264 | self.get(tok).cmp(other.get(tok)) |
265 | } |
266 | } |
267 | |
268 | impl<T: fmt::Display> fmt::Display for Sticky<T> { |
269 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { |
270 | crate::stack_token!(tok); |
271 | fmt::Display::fmt(self.get(tok), f) |
272 | } |
273 | } |
274 | |
275 | impl<T: fmt::Debug> fmt::Debug for Sticky<T> { |
276 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { |
277 | crate::stack_token!(tok); |
278 | match self.try_get(tok) { |
279 | Ok(value) => f.debug_struct("Sticky" ).field("value" , value).finish(), |
280 | Err(..) => { |
281 | struct InvalidPlaceholder; |
282 | impl fmt::Debug for InvalidPlaceholder { |
283 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
284 | f.write_str("<invalid thread>" ) |
285 | } |
286 | } |
287 | |
288 | f.debug_struct("Sticky" ) |
289 | .field("value" , &InvalidPlaceholder) |
290 | .finish() |
291 | } |
292 | } |
293 | } |
294 | } |
295 | |
296 | // similar as for fragile ths type is sync because it only accesses TLS data |
297 | // which is thread local. There is nothing that needs to be synchronized. |
298 | unsafe impl<T> Sync for Sticky<T> {} |
299 | |
300 | // The entire point of this type is to be Send |
301 | unsafe impl<T> Send for Sticky<T> {} |
302 | |
#[test]
fn test_basic() {
    use std::thread;

    let val = Sticky::new(true);
    crate::stack_token!(tok);
    assert_eq!(val.to_string(), "true");
    assert!(*val.get(tok));
    assert!(val.try_get(tok).is_ok());

    // Once moved to another thread, the value must no longer be reachable.
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(val.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
318 | |
#[test]
fn test_mut() {
    let mut flag = Sticky::new(true);
    crate::stack_token!(tok);

    // Write through the mutable borrow, then observe the change both via
    // `Display` and via the shared accessor.
    *flag.get_mut(tok) = false;
    assert_eq!(flag.to_string(), "false");
    assert!(!*flag.get(tok));
}
327 | |
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let val = Sticky::new(true);
    let worker = thread::spawn(move || {
        crate::stack_token!(tok);
        // `get` panics off the owning thread; `unwrap` on the join result
        // re-raises that panic in the test thread.
        val.get(tok);
    });
    worker.join().unwrap();
}
340 | |
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Sets the shared flag when dropped.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let val = Sticky::new(X(Arc::clone(&dropped)));
    mem::drop(val);

    // Dropping on the owning thread must run `X`'s destructor immediately.
    assert!(dropped.load(Ordering::SeqCst));
}
356 | |
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// Sets the shared flag when dropped.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&was_called);

    thread::spawn(move || {
        let val = Sticky::new(X(Arc::clone(&flag)));

        // Move the sticky to yet another thread and let it be dropped
        // there; that must not run `X`'s destructor.
        let inner = thread::spawn(move || {
            crate::stack_token!(tok);
            val.try_get(tok).ok();
        })
        .join();
        assert!(inner.is_ok());
        assert!(!flag.load(Ordering::SeqCst));
    })
    .join()
    .unwrap();

    // The owning thread has finished, so its teardown must have dropped `X`.
    assert!(was_called.load(Ordering::SeqCst));
}
392 | |
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is not `Send`, yet the `Sticky` wrapping it can still be moved;
    // the receiving thread just cannot look inside.
    let val = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(val.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
405 | |
#[test]
fn test_two_stickies() {
    // A zero-sized type with a no-op destructor: `mem::needs_drop` is true,
    // so both values go through the `Drop` path that takes them out of the
    // registry.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {}
    }

    // Boxing a zero-sized value does not allocate, so both entries likely
    // carry the same dangling pointer. Presumably this checks that the
    // registry keys entries by id rather than by address — TODO confirm.
    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // Dropping both on the owning thread must not interfere or double-free.
    drop(first);
    drop(second);
}
424 | |