1 | // There's a lot of scary concurrent code in this module, but it is copied from |
2 | // `std::sync::Once` with two changes: |
3 | // * no poisoning |
4 | // * init function can fail |
5 | |
6 | use std::{ |
7 | cell::{Cell, UnsafeCell}, |
8 | panic::{RefUnwindSafe, UnwindSafe}, |
9 | sync::atomic::{AtomicBool, AtomicPtr, Ordering}, |
10 | thread::{self, Thread}, |
11 | }; |
12 | |
/// A cell holding an optional `T` that is initialized at most once through a
/// shared reference and can be shared between threads.
#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
    queue: AtomicPtr<Waiter>,
    // `None` until an initializer succeeds. Through `&self` it is written only
    // by the thread that moved `queue` into the `RUNNING` state.
    value: UnsafeCell<Option<T>>,
}
26 | |
27 | // Why do we need `T: Send`? |
28 | // Thread A creates a `OnceCell` and shares it with |
29 | // scoped thread B, which fills the cell, which is |
30 | // then destroyed by A. That is, destructor observes |
31 | // a sent value. |
32 | unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} |
33 | unsafe impl<T: Send> Send for OnceCell<T> {} |
34 | |
35 | impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {} |
36 | impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {} |
37 | |
38 | impl<T> OnceCell<T> { |
39 | pub(crate) const fn new() -> OnceCell<T> { |
40 | OnceCell { queue: AtomicPtr::new(INCOMPLETE_PTR), value: UnsafeCell::new(None) } |
41 | } |
42 | |
43 | pub(crate) const fn with_value(value: T) -> OnceCell<T> { |
44 | OnceCell { queue: AtomicPtr::new(COMPLETE_PTR), value: UnsafeCell::new(Some(value)) } |
45 | } |
46 | |
47 | /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst). |
48 | #[inline ] |
49 | pub(crate) fn is_initialized(&self) -> bool { |
50 | // An `Acquire` load is enough because that makes all the initialization |
51 | // operations visible to us, and, this being a fast path, weaker |
52 | // ordering helps with performance. This `Acquire` synchronizes with |
53 | // `SeqCst` operations on the slow path. |
54 | self.queue.load(Ordering::Acquire) == COMPLETE_PTR |
55 | } |
56 | |
57 | /// Safety: synchronizes with store to value via SeqCst read from state, |
58 | /// writes value only once because we never get to INCOMPLETE state after a |
59 | /// successful write. |
60 | #[cold ] |
61 | pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E> |
62 | where |
63 | F: FnOnce() -> Result<T, E>, |
64 | { |
65 | let mut f = Some(f); |
66 | let mut res: Result<(), E> = Ok(()); |
67 | let slot: *mut Option<T> = self.value.get(); |
68 | initialize_or_wait( |
69 | &self.queue, |
70 | Some(&mut || { |
71 | let f = unsafe { f.take().unwrap_unchecked() }; |
72 | match f() { |
73 | Ok(value) => { |
74 | unsafe { *slot = Some(value) }; |
75 | true |
76 | } |
77 | Err(err) => { |
78 | res = Err(err); |
79 | false |
80 | } |
81 | } |
82 | }), |
83 | ); |
84 | res |
85 | } |
86 | |
87 | #[cold ] |
88 | pub(crate) fn wait(&self) { |
89 | initialize_or_wait(&self.queue, None); |
90 | } |
91 | |
92 | /// Get the reference to the underlying value, without checking if the cell |
93 | /// is initialized. |
94 | /// |
95 | /// # Safety |
96 | /// |
97 | /// Caller must ensure that the cell is in initialized state, and that |
98 | /// the contents are acquired by (synchronized to) this thread. |
99 | pub(crate) unsafe fn get_unchecked(&self) -> &T { |
100 | debug_assert!(self.is_initialized()); |
101 | let slot = &*self.value.get(); |
102 | slot.as_ref().unwrap_unchecked() |
103 | } |
104 | |
105 | /// Gets the mutable reference to the underlying value. |
106 | /// Returns `None` if the cell is empty. |
107 | pub(crate) fn get_mut(&mut self) -> Option<&mut T> { |
108 | // Safe b/c we have a unique access. |
109 | unsafe { &mut *self.value.get() }.as_mut() |
110 | } |
111 | |
112 | /// Consumes this `OnceCell`, returning the wrapped value. |
113 | /// Returns `None` if the cell was empty. |
114 | #[inline ] |
115 | pub(crate) fn into_inner(self) -> Option<T> { |
116 | // Because `into_inner` takes `self` by value, the compiler statically |
117 | // verifies that it is not currently borrowed. |
118 | // So, it is safe to move out `Option<T>`. |
119 | self.value.into_inner() |
120 | } |
121 | } |
122 | |
123 | // Three states that a OnceCell can be in, encoded into the lower bits of `queue` in |
124 | // the OnceCell structure. |
125 | const INCOMPLETE: usize = 0x0; |
126 | const RUNNING: usize = 0x1; |
127 | const COMPLETE: usize = 0x2; |
128 | const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter; |
129 | const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter; |
130 | |
131 | // Mask to learn about the state. All other bits are the queue of waiters if |
132 | // this is in the RUNNING state. |
133 | const STATE_MASK: usize = 0x3; |
134 | |
/// Representation of a node in the linked list of waiters, which can exist in
/// the `INCOMPLETE` and `RUNNING` states.
/// A waiter is stored on the stack of the waiting thread.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    // Handle used to unpark the waiting thread. Taken by the thread that
    // drains the list.
    thread: Cell<Option<Thread>>,
    // Stored as `true` (with `Release`) once the node has been dequeued. After
    // that the waiting thread may return and the node may go away.
    signaled: AtomicBool,
    // Next node in the list, with the state bits cleared; null at the end.
    next: *mut Waiter,
}
143 | |
/// Drains and notifies the queue of waiters on drop.
///
/// Held by the thread running an initializer. When it is dropped, including
/// during unwinding, it stores `new_queue` into `queue` and wakes every
/// queued waiter.
struct Guard<'a> {
    // The cell's state word.
    queue: &'a AtomicPtr<Waiter>,
    // State published on drop. Starts as `INCOMPLETE_PTR`; the owner switches
    // it to `COMPLETE_PTR` after a successful initialization.
    new_queue: *mut Waiter,
}
149 | |
150 | impl Drop for Guard<'_> { |
151 | fn drop(&mut self) { |
152 | let queue: *mut Waiter = self.queue.swap(self.new_queue, order:Ordering::AcqRel); |
153 | |
154 | let state: usize = strict::addr(ptr:queue) & STATE_MASK; |
155 | assert_eq!(state, RUNNING); |
156 | |
157 | unsafe { |
158 | let mut waiter: *mut Waiter = strict::map_addr(ptr:queue, |q: usize| q & !STATE_MASK); |
159 | while !waiter.is_null() { |
160 | let next: *mut Waiter = (*waiter).next; |
161 | let thread: Thread = (*waiter).thread.take().unwrap(); |
162 | (*waiter).signaled.store(val:true, order:Ordering::Release); |
163 | waiter = next; |
164 | thread.unpark(); |
165 | } |
166 | } |
167 | } |
168 | } |
169 | |
170 | // Corresponds to `std::sync::Once::call_inner`. |
171 | // |
172 | // Originally copied from std, but since modified to remove poisoning and to |
173 | // support wait. |
174 | // |
175 | // Note: this is intentionally monomorphic |
176 | #[inline (never)] |
177 | fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) { |
178 | let mut curr_queue = queue.load(Ordering::Acquire); |
179 | |
180 | loop { |
181 | let curr_state = strict::addr(curr_queue) & STATE_MASK; |
182 | match (curr_state, &mut init) { |
183 | (COMPLETE, _) => return, |
184 | (INCOMPLETE, Some(init)) => { |
185 | let exchange = queue.compare_exchange( |
186 | curr_queue, |
187 | strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING), |
188 | Ordering::Acquire, |
189 | Ordering::Acquire, |
190 | ); |
191 | if let Err(new_queue) = exchange { |
192 | curr_queue = new_queue; |
193 | continue; |
194 | } |
195 | let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR }; |
196 | if init() { |
197 | guard.new_queue = COMPLETE_PTR; |
198 | } |
199 | return; |
200 | } |
201 | (INCOMPLETE, None) | (RUNNING, _) => { |
202 | wait(queue, curr_queue); |
203 | curr_queue = queue.load(Ordering::Acquire); |
204 | } |
205 | _ => debug_assert!(false), |
206 | } |
207 | } |
208 | } |
209 | |
210 | fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) { |
211 | let curr_state = strict::addr(curr_queue) & STATE_MASK; |
212 | loop { |
213 | let node = Waiter { |
214 | thread: Cell::new(Some(thread::current())), |
215 | signaled: AtomicBool::new(false), |
216 | next: strict::map_addr(curr_queue, |q| q & !STATE_MASK), |
217 | }; |
218 | let me = &node as *const Waiter as *mut Waiter; |
219 | |
220 | let exchange = queue.compare_exchange( |
221 | curr_queue, |
222 | strict::map_addr(me, |q| q | curr_state), |
223 | Ordering::Release, |
224 | Ordering::Relaxed, |
225 | ); |
226 | if let Err(new_queue) = exchange { |
227 | if strict::addr(new_queue) & STATE_MASK != curr_state { |
228 | return; |
229 | } |
230 | curr_queue = new_queue; |
231 | continue; |
232 | } |
233 | |
234 | while !node.signaled.load(Ordering::Acquire) { |
235 | thread::park(); |
236 | } |
237 | break; |
238 | } |
239 | } |
240 | |
// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing functions rather than a trait to keep things simple and
// avoid any potential conflicts with a future stable std API.
mod strict {
    /// Returns the address of `ptr` as an integer.
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    /// Returns a pointer with address `addr` that keeps the provenance of `ptr`.
    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the meantime, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // The canonical desugaring. `pointer::cast` is stable since 1.38, which
        // this file already exceeds (it uses `unwrap_unchecked`), so the old
        // `as`-cast workaround is no longer needed. `wrapping_offset` is safe
        // for any offset, and the result remembers `ptr`'s allocation.
        ptr.cast::<u8>().wrapping_offset(offset).cast::<T>()
    }

    /// Replaces the address of `ptr` with `f(address)`, keeping its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}
288 | |
// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        // Test-only convenience: initialize with an infallible closure.
        // `Void` has no values, so `initialize` can never return `Err` here.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    // The initializer runs on the first call only; the second call is a no-op.
    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    // Ten spawned threads and the test thread all call `init` on the same
    // cell. Each initializer asserts that `RUN` is still false, so a second
    // initializer would fail. Every caller then checks `RUN` after `init`.
    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                // Yield a few times so the threads are more likely to
                // interleave with each other and with the test thread.
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        // Wait for all ten spawned threads to report back.
        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    // Unlike `std::sync::Once`, this cell has no poisoning: after an
    // initializer panics, the next call runs its own initializer.
    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer (what std calls "poisoning" the once).
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // The next initializer still runs.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // After a success, further calls are plain no-ops.
        O.init(|| {});
    }

    // A thread that calls `init` while another thread's initializer is
    // running must not run its own initializer.
    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // Start from a cell whose first initializer panicked.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // t1's initializer reports through `tx1` that it is running, then
        // blocks on `rx2`, keeping the cell in the RUNNING state.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // t2 calls `init` while t1's initializer is (most likely still) running.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        // Let t1's initializer finish.
        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    // Pins the layout on 64-bit targets: `OnceCell<u32>` must be 16 bytes,
    // i.e. the pointer-sized state word plus `Option<u32>`.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}
416 | |