// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
// * no poisoning
// * init function can fail

use std::{
    cell::{Cell, UnsafeCell},
    panic::{RefUnwindSafe, UnwindSafe},
    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
    thread::{self, Thread},
};

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in the two low bits (see the illustrative sketch just
    // below this struct). Only the `INCOMPLETE` and `RUNNING` states allow
    // waiters.
    queue: AtomicPtr<Waiter>,
    value: UnsafeCell<Option<T>>,
}
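
// Illustrative sketch of that encoding (not used by the implementation):
// `Waiter` is `#[repr(align(4))]`, so a waiter's address always has its two low
// bits clear, and a single word can carry both the head of the waiter list and
// the state tag:
//
//     let tagged = strict::map_addr(head, |a| a | RUNNING);          // pack
//     let state = strict::addr(tagged) & STATE_MASK;                 // -> RUNNING
//     let waiters = strict::map_addr(tagged, |a| a & !STATE_MASK);   // -> head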

// Why do we need `T: Send`?
// Thread A creates an `OnceCell` and shares it with scoped thread B, which
// fills the cell; the cell is then destroyed by A. That is, the destructor
// observes a sent value. A minimal sketch of that scenario follows the
// `unsafe impl`s below.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}
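
// A minimal sketch of that scenario (hypothetical caller code, assuming std's
// scoped threads; not part of this module):
//
//     let cell: OnceCell<String> = OnceCell::new();
//     std::thread::scope(|s| {
//         s.spawn(|| {
//             let _ = cell.initialize(|| Ok::<_, ()>(String::from("value")));
//         });
//     }); // B is joined here
//     drop(cell); // A drops a value produced by B, so `T` must be `Send`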

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell { queue: AtomicPtr::new(INCOMPLETE_PTR), value: UnsafeCell::new(None) }
    }

    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
        OnceCell { queue: AtomicPtr::new(COMPLETE_PTR), value: UnsafeCell::new(Some(value)) }
    }

    /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // An `Acquire` load is enough because it makes all the initialization
        // operations visible to us and, this being a fast path, weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // the `Release`/`AcqRel` stores on the slow path.
        self.queue.load(Ordering::Acquire) == COMPLETE_PTR
    }

    /// Safety: synchronizes with the store to value via an `Acquire` read of
    /// the state; writes the value only once because we never get back to the
    /// `INCOMPLETE` state after a successful write.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut f = Some(f);
        let mut res: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();
        initialize_or_wait(
            &self.queue,
            Some(&mut || {
                let f = unsafe { f.take().unwrap_unchecked() };
                match f() {
                    Ok(value) => {
                        unsafe { *slot = Some(value) };
                        true
                    }
                    Err(err) => {
                        res = Err(err);
                        false
                    }
                }
            }),
        );
        res
    }

    #[cold]
    pub(crate) fn wait(&self) {
        initialize_or_wait(&self.queue, None);
    }

    /// Get the reference to the underlying value, without checking if the cell
    /// is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in initialized state, and that
    /// the contents are acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot = &*self.value.get();
        slot.as_ref().unwrap_unchecked()
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // Safe because we have unique access through `&mut self`.
        unsafe { &mut *self.value.get() }.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically
        // verifies that it is not currently borrowed.
        // So, it is safe to move out `Option<T>`.
        self.value.into_inner()
    }
}
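
// A sketch of how a caller might combine these primitives into a safe,
// `get_or_try_init`-style accessor (hypothetical wrapper, not part of this
// module):
//
//     fn get_or_try_init<T, E>(
//         cell: &OnceCell<T>,
//         f: impl FnOnce() -> Result<T, E>,
//     ) -> Result<&T, E> {
//         if !cell.is_initialized() {
//             cell.initialize(f)?;
//         }
//         // Safety: either `initialize` returned `Ok`, or the `Acquire` load in
//         // `is_initialized` observed COMPLETE, so the value is written and
//         // synchronized with this thread.
//         Ok(unsafe { cell.get_unchecked() })
//     }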

// Three states that an `OnceCell` can be in, encoded into the lower bits of
// `queue` in the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;
const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiter is stored on the stack of the waiting thread.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    thread: Cell<Option<Thread>>,
    signaled: AtomicBool,
    next: *mut Waiter,
}

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
    queue: &'a AtomicPtr<Waiter>,
    new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        unsafe {
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
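                // Read `next` and take the `Thread` handle *before* setting
                // `signaled`: once the flag is observed, the waiting thread may
                // return from `wait` and pop the stack frame that owns this node.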
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic.
#[inline(never)]
fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
    let mut curr_queue = queue.load(Ordering::Acquire);

    loop {
        let curr_state = strict::addr(curr_queue) & STATE_MASK;
        match (curr_state, &mut init) {
            (COMPLETE, _) => return,
            (INCOMPLETE, Some(init)) => {
                let exchange = queue.compare_exchange(
                    curr_queue,
                    strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
                    Ordering::Acquire,
                    Ordering::Acquire,
                );
                if let Err(new_queue) = exchange {
                    curr_queue = new_queue;
                    continue;
                }
                let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
                if init() {
                    guard.new_queue = COMPLETE_PTR;
                }
                return;
            }
            (INCOMPLETE, None) | (RUNNING, _) => {
                wait(queue, curr_queue);
                curr_queue = queue.load(Ordering::Acquire);
            }
            _ => debug_assert!(false),
        }
    }
}

fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
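            // If the state bits changed under us (e.g. initialization finished
            // and the queue is now COMPLETE), our node was never published, so
            // bail out and let the caller re-examine the queue.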
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            curr_queue = new_queue;
            continue;
        }

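        // Our node is now linked into the queue; park until the initializing
        // thread signals it. `park` can return spuriously, so loop on the flag.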
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}

// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing functions rather than a trait to keep things simple and
// avoid any potential conflicts with a future stable std API.
mod strict {
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the mean-time, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}

// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }
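
    // Not taken from std: a small additional check sketching the fallible
    // `initialize` path specific to this cell. A failed init should return the
    // error and leave the cell empty, so a later init can still succeed.
    #[test]
    fn initialize_err_is_transient() {
        static O: OnceCell<u32> = OnceCell::new();

        let res = O.initialize(|| Err::<u32, &str>("nope"));
        assert_eq!(res, Err("nope"));
        assert!(!O.is_initialized());

        let res = O.initialize(|| Ok::<u32, &str>(92));
        assert_eq!(res, Ok(()));
        // Safety: the cell was just initialized successfully on this thread.
        assert_eq!(unsafe { *O.get_unchecked() }, 92);
    }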

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }
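
    // Not taken from std: a small sketch of `wait`, which should block until
    // some other thread completes initialization (the 10ms sleep only makes it
    // likely that the spawned thread really parks first; the test passes either
    // way because `wait` on a completed cell returns immediately).
    #[test]
    fn wait_unblocks_after_init() {
        static O: OnceCell<u32> = OnceCell::new();

        let t = thread::spawn(|| {
            O.wait();
            assert!(O.is_initialized());
        });

        thread::sleep(std::time::Duration::from_millis(10));
        O.init(|| 92);
        assert!(t.join().is_ok());
    }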

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
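
    // Not taken from std: a sanity check for the non-atomic accessors
    // (`with_value`, `get_mut`, `into_inner`) on both full and empty cells.
    #[test]
    fn with_value_get_mut_into_inner() {
        let mut cell = OnceCell::with_value(92);
        assert!(cell.is_initialized());
        *cell.get_mut().unwrap() += 1;
        assert_eq!(cell.into_inner(), Some(93));

        let empty: OnceCell<u32> = OnceCell::new();
        assert_eq!(empty.into_inner(), None);
    }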
}