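// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with these changes:
// * no poisoning
// * init function can fail
// * threads can wait for the cell to be initialized without running an
//   initializer themselves (`OnceCell::wait`)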
5
6use std::{
7 cell::{Cell, UnsafeCell},
8 panic::{RefUnwindSafe, UnwindSafe},
9 sync::atomic::{AtomicBool, AtomicPtr, Ordering},
10 thread::{self, Thread},
11};
12
13#[derive(Debug)]
14pub(crate) struct OnceCell<T> {
15 // This `queue` field is the core of the implementation. It encodes two
16 // pieces of information:
17 //
18 // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
19 // * Linked list of threads waiting for the current cell.
20 //
21 // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
22 // allow waiters.
23 queue: AtomicPtr<Waiter>,
24 value: UnsafeCell<Option<T>>,
25}
26
27// Why do we need `T: Send`?
28// Thread A creates a `OnceCell` and shares it with
29// scoped thread B, which fills the cell, which is
30// then destroyed by A. That is, destructor observes
31// a sent value.
32unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
33unsafe impl<T: Send> Send for OnceCell<T> {}
34
35impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
36impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
37
38impl<T> OnceCell<T> {
39 pub(crate) const fn new() -> OnceCell<T> {
40 OnceCell { queue: AtomicPtr::new(INCOMPLETE_PTR), value: UnsafeCell::new(None) }
41 }
42
43 pub(crate) const fn with_value(value: T) -> OnceCell<T> {
44 OnceCell { queue: AtomicPtr::new(COMPLETE_PTR), value: UnsafeCell::new(Some(value)) }
45 }
46
47 /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
48 #[inline]
49 pub(crate) fn is_initialized(&self) -> bool {
50 // An `Acquire` load is enough because that makes all the initialization
51 // operations visible to us, and, this being a fast path, weaker
52 // ordering helps with performance. This `Acquire` synchronizes with
53 // `SeqCst` operations on the slow path.
54 self.queue.load(Ordering::Acquire) == COMPLETE_PTR
55 }
56
57 /// Safety: synchronizes with store to value via SeqCst read from state,
58 /// writes value only once because we never get to INCOMPLETE state after a
59 /// successful write.
60 #[cold]
61 pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
62 where
63 F: FnOnce() -> Result<T, E>,
64 {
65 let mut f = Some(f);
66 let mut res: Result<(), E> = Ok(());
67 let slot: *mut Option<T> = self.value.get();
68 initialize_or_wait(
69 &self.queue,
70 Some(&mut || {
71 let f = unsafe { f.take().unwrap_unchecked() };
72 match f() {
73 Ok(value) => {
74 unsafe { *slot = Some(value) };
75 true
76 }
77 Err(err) => {
78 res = Err(err);
79 false
80 }
81 }
82 }),
83 );
84 res
85 }
86
87 #[cold]
88 pub(crate) fn wait(&self) {
89 initialize_or_wait(&self.queue, None);
90 }
91
92 /// Get the reference to the underlying value, without checking if the cell
93 /// is initialized.
94 ///
95 /// # Safety
96 ///
97 /// Caller must ensure that the cell is in initialized state, and that
98 /// the contents are acquired by (synchronized to) this thread.
99 pub(crate) unsafe fn get_unchecked(&self) -> &T {
100 debug_assert!(self.is_initialized());
101 let slot = &*self.value.get();
102 slot.as_ref().unwrap_unchecked()
103 }
104
105 /// Gets the mutable reference to the underlying value.
106 /// Returns `None` if the cell is empty.
107 pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
108 // Safe b/c we have a unique access.
109 unsafe { &mut *self.value.get() }.as_mut()
110 }
111
112 /// Consumes this `OnceCell`, returning the wrapped value.
113 /// Returns `None` if the cell was empty.
114 #[inline]
115 pub(crate) fn into_inner(self) -> Option<T> {
116 // Because `into_inner` takes `self` by value, the compiler statically
117 // verifies that it is not currently borrowed.
118 // So, it is safe to move out `Option<T>`.
119 self.value.into_inner()
120 }
121}
122
// The three states a `OnceCell` can be in. They are encoded in the two low bits
// of the `queue` pointer of the `OnceCell` structure.
const INCOMPLETE: usize = 0b00;
const RUNNING: usize = 0b01;
const COMPLETE: usize = 0b10;

// Queue values for a cell with no waiters. After masking off the state bits
// they are null, so they are never dereferenced.
const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;

// Selects the state bits of `queue`. All other bits hold the head of the waiter
// list, which may be non-null while the state is `INCOMPLETE` or `RUNNING`.
const STATE_MASK: usize = 0b11;

/// A node in the linked list of threads waiting on a cell.
///
/// Each waiter is stored on the stack of the waiting thread.
// `align(4)` keeps the two low bits of every `*mut Waiter` clear, so they can
// carry the state.
#[repr(align(4))]
struct Waiter {
    /// Handle used to unpark the waiting thread; taken when the queue is drained.
    thread: Cell<Option<Thread>>,
    /// Stored as `true` (with `Release`) once the waiter may stop parking.
    signaled: AtomicBool,
    /// The next node in the list, or null at the tail.
    next: *mut Waiter,
}
143
/// Drains and notifies the queue of waiters on drop.
///
/// Created by `initialize_or_wait` after it has moved the cell into `RUNNING`.
/// On drop, which also happens while unwinding out of the initializer, it
/// stores `new_queue` as the cell's new state and wakes every queued waiter.
struct Guard<'a> {
    // The cell's tagged state/queue word.
    queue: &'a AtomicPtr<Waiter>,
    // The state to publish on drop. It starts as `INCOMPLETE_PTR` and is set
    // to `COMPLETE_PTR` only when the initializer reports success.
    new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Publish the new state and take the whole waiter list in one step.
        // The `Release` half makes the initializer's writes (e.g. the stored
        // value) visible to threads that later `Acquire`-load this state. The
        // `Acquire` half makes the waiter nodes (pushed with `Release` CAS in
        // `wait`) visible here before we read them.
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        // Only the thread that claimed `RUNNING` creates a guard, so any other
        // state means the state machine is broken.
        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        // SAFETY: every non-null node in the list is a `Waiter` on the stack of
        // a thread that keeps parking until its `signaled` flag is set. The
        // node therefore stays alive until we store `true` into it.
        unsafe {
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
                // Read everything needed from the node *before* setting
                // `signaled`. Once the flag is visible, the waiting thread may
                // return and its stack-allocated node becomes invalid.
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                // `thread` is an owned handle, so unparking does not touch the
                // (possibly already dead) node.
                thread.unpark();
            }
        }
    }
}
169
170// Corresponds to `std::sync::Once::call_inner`.
171//
172// Originally copied from std, but since modified to remove poisoning and to
173// support wait.
174//
175// Note: this is intentionally monomorphic
176#[inline(never)]
177fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
178 let mut curr_queue = queue.load(Ordering::Acquire);
179
180 loop {
181 let curr_state = strict::addr(curr_queue) & STATE_MASK;
182 match (curr_state, &mut init) {
183 (COMPLETE, _) => return,
184 (INCOMPLETE, Some(init)) => {
185 let exchange = queue.compare_exchange(
186 curr_queue,
187 strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
188 Ordering::Acquire,
189 Ordering::Acquire,
190 );
191 if let Err(new_queue) = exchange {
192 curr_queue = new_queue;
193 continue;
194 }
195 let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
196 if init() {
197 guard.new_queue = COMPLETE_PTR;
198 }
199 return;
200 }
201 (INCOMPLETE, None) | (RUNNING, _) => {
202 wait(queue, curr_queue);
203 curr_queue = queue.load(Ordering::Acquire);
204 }
205 _ => debug_assert!(false),
206 }
207 }
208}
209
/// Blocks the current thread on the cell whose queue word was last observed as
/// `curr_queue`.
///
/// Pushes a stack-allocated `Waiter` node onto the list with a CAS that keeps
/// the state bits unchanged. If the state bits change before the push
/// succeeds, it returns immediately so the caller can re-read the state. After
/// a successful push, it parks until `Guard::drop` sets the node's `signaled`
/// flag.
fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    // The state this thread intends to wait on.
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        // A fresh node for every attempt, linked to the current list head.
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        // Push `node` as the new head, tagged with the unchanged state.
        // `Release` publishes the node's fields to the thread that drains the
        // queue.
        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            // The state itself changed: stop waiting and let the caller
            // re-evaluate.
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            // Same state, but the waiter list changed; retry with the new head.
            curr_queue = new_queue;
            continue;
        }

        // `thread::park` may return spuriously, so loop until actually
        // signaled. The `Acquire` load pairs with the `Release` store in
        // `Guard::drop`. `node` must stay alive on this frame until then.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}
240
// Polyfill of the strict-provenance pointer APIs, adapted from
// https://crates.io/crates/sptr.
//
// Free-standing functions are used instead of an extension trait to keep
// things simple and to avoid any potential conflicts with future stable std
// APIs.
mod strict {
    /// Returns the address of `ptr`, discarding its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): should be a compiler intrinsic.
        // SAFETY: pointer-to-integer transmutes are valid, as long as losing
        // the provenance is acceptable.
        let address = unsafe { core::mem::transmute::<*mut T, usize>(ptr) };
        address
    }

    /// Returns a pointer whose address is `addr` and whose provenance is that
    /// of `ptr`.
    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): should be a compiler intrinsic.
        //
        // Until then, this operation is defined "as if" it were a
        // `wrapping_offset` from `ptr`. Emulating it that way keeps the
        // provenance of `ptr`, even under today's compiler.
        let delta = (addr as isize).wrapping_sub(self::addr(ptr) as isize);

        // `pointer::cast` would be the canonical spelling, but it was only
        // stabilized in 1.38, so plain `as` casts are used instead.
        let byte_ptr = ptr as *mut u8;
        byte_ptr.wrapping_offset(delta) as *mut T
    }

    /// Replaces the address of `ptr` with `f(address)`, keeping its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        let new_addr = f(self::addr(ptr));
        self::with_addr(ptr, new_addr)
    }
}
288
// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Test helper: initializes the cell with an infallible `f`.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        // An atomic instead of `static mut`: no `unsafe` is needed, and `swap`
        // catches a second initializer even if two runs overlap. `Relaxed` is
        // enough for the `swap`, which always sees the latest value. The later
        // `load` observes `true` only because the cell's own synchronization
        // orders it after the initializer, which is what this test checks.
        static RUN: AtomicBool = AtomicBool::new(false);

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                O.init(|| {
                    assert!(!RUN.swap(true, Ordering::Relaxed));
                });
                assert!(RUN.load(Ordering::Relaxed));
                tx.send(()).unwrap();
            });
        }

        O.init(|| {
            assert!(!RUN.swap(true, Ordering::Relaxed));
        });
        assert!(RUN.load(Ordering::Relaxed));

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}
416