1 | #![cfg_attr (not(feature = "sync" ), allow(unreachable_pub, dead_code))] |
2 | //! # Implementation Details. |
3 | //! |
4 | //! The semaphore is implemented using an intrusive linked list of waiters. An |
5 | //! atomic counter tracks the number of available permits. If the semaphore does |
6 | //! not contain the required number of permits, the task attempting to acquire |
7 | //! permits places its waker at the end of a queue. When new permits are made |
8 | //! available (such as by releasing an initial acquisition), they are assigned |
9 | //! to the task at the front of the queue, waking that task if its requested |
10 | //! number of permits is met. |
11 | //! |
12 | //! Because waiters are enqueued at the back of the linked list and dequeued |
13 | //! from the front, the semaphore is fair. Tasks trying to acquire large numbers |
14 | //! of permits at a time will always be woken eventually, even if many other |
15 | //! tasks are acquiring smaller numbers of permits. This means that in a |
16 | //! use-case like tokio's read-write lock, writers will not be starved by |
17 | //! readers. |
18 | use crate::loom::cell::UnsafeCell; |
19 | use crate::loom::sync::atomic::AtomicUsize; |
20 | use crate::loom::sync::{Mutex, MutexGuard}; |
21 | use crate::util::linked_list::{self, LinkedList}; |
22 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
23 | use crate::util::trace; |
24 | use crate::util::WakeList; |
25 | |
26 | use std::future::Future; |
27 | use std::marker::PhantomPinned; |
28 | use std::pin::Pin; |
29 | use std::ptr::NonNull; |
30 | use std::sync::atomic::Ordering::*; |
31 | use std::task::{ready, Context, Poll, Waker}; |
32 | use std::{cmp, fmt}; |
33 | |
/// An asynchronous counting semaphore which permits waiting on multiple permits at once.
pub(crate) struct Semaphore {
    /// The queue of waiting tasks together with the wait-list `closed` flag,
    /// guarded by a mutex.
    waiters: Mutex<Waitlist>,
    /// The current number of available permits in the semaphore.
    ///
    /// The count is stored shifted left by `PERMIT_SHIFT`; the freed
    /// least-significant bit is the `CLOSED` flag.
    permits: AtomicUsize,
    /// Tracing span representing this semaphore as a runtime resource.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
42 | |
/// State protected by the semaphore's `waiters` mutex.
struct Waitlist {
    /// Intrusive list of queued waiters. New waiters are pushed onto the
    /// front and permits are handed out from the back, so the back of the
    /// list is the oldest waiter.
    queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
    /// Set to `true` by `Semaphore::close`; once set, acquisitions that reach
    /// the wait list fail with `AcquireError`.
    closed: bool,
}
47 | |
/// Error returned from the [`Semaphore::try_acquire`] function.
///
/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire
// `PartialEq`/`Eq` are derived so errors can be compared against a variant
// directly.
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore has been [closed] and cannot issue new permits.
    ///
    /// [closed]: crate::sync::Semaphore::close
    Closed,

    /// The semaphore has no available permits.
    NoPermits,
}
/// Error returned from the [`Semaphore::acquire`] function.
///
/// An `acquire` operation can only fail if the semaphore has been
/// [closed].
///
/// [closed]: crate::sync::Semaphore::close
/// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire
// The private unit field keeps the type opaque: it can only be constructed
// inside this module (via `AcquireError::closed`).
#[derive(Debug)]
pub struct AcquireError(());
70 | |
/// Future returned by `Semaphore::acquire`; resolves once `num_permits`
/// permits have been acquired, or fails if the semaphore is closed.
pub(crate) struct Acquire<'a> {
    /// Intrusive wait-list entry. `Waiter` is `!Unpin`, so it must not move
    /// while it may be linked into the semaphore's wait list.
    node: Waiter,
    /// The semaphore permits are being acquired from.
    semaphore: &'a Semaphore,
    /// Total number of permits originally requested.
    num_permits: usize,
    /// Whether `node` may currently be linked into the wait list. `Drop`
    /// checks this to decide whether it must lock the list and unlink `node`.
    queued: bool,
}
77 | |
/// An entry in the wait queue.
struct Waiter {
    /// The number of permits this waiter still needs.
    ///
    /// Initialized to the requested count in `Waiter::new` and decremented by
    /// `assign_permits`; the waiter is satisfied once it reaches zero.
    /// (Whether the waiter is queued is tracked separately, by
    /// `Acquire::queued`.)
    state: AtomicUsize,

    /// The waker to notify the task awaiting permits.
    ///
    /// # Safety
    ///
    /// This may only be accessed while the wait queue is locked.
    waker: UnsafeCell<Option<Waker>>,

    /// Intrusive linked-list pointers.
    ///
    /// # Safety
    ///
    /// This may only be accessed while the wait queue is locked.
    ///
    /// TODO: Ideally, we would be able to use loom to enforce that
    /// this isn't accessed concurrently. However, it is difficult to
    /// use a `UnsafeCell` here, since the `Link` trait requires _returning_
    /// references to `Pointers`, and `UnsafeCell` requires that checked access
    /// take place inside a closure. We should consider changing `Pointers` to
    /// use `UnsafeCell` internally.
    pointers: linked_list::Pointers<Waiter>,

    /// Tracing spans for the async operation this waiter belongs to.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    ctx: trace::AsyncOpTracingCtx,

    /// Should not be `Unpin`: the wait list stores raw pointers to this value.
    _p: PhantomPinned,
}
113 | |
// Generates `Waiter::addr_of_pointers`, which the `linked_list::Link` impl for
// `Waiter` uses to locate the intrusive `pointers` field from a raw
// `NonNull<Waiter>`. NOTE(review): presumably the macro computes the field
// address without materializing a reference to the whole `Waiter` — confirm
// against the macro definition.
generate_addr_of_methods! {
    impl<> Waiter {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
            &self.pointers
        }
    }
}
121 | |
122 | impl Semaphore { |
123 | /// The maximum number of permits which a semaphore can hold. |
124 | /// |
125 | /// Note that this reserves three bits of flags in the permit counter, but |
126 | /// we only actually use one of them. However, the previous semaphore |
127 | /// implementation used three bits, so we will continue to reserve them to |
128 | /// avoid a breaking change if additional flags need to be added in the |
129 | /// future. |
130 | pub(crate) const MAX_PERMITS: usize = usize::MAX >> 3; |
131 | const CLOSED: usize = 1; |
132 | // The least-significant bit in the number of permits is reserved to use |
133 | // as a flag indicating that the semaphore has been closed. Consequently |
134 | // PERMIT_SHIFT is used to leave that bit for that purpose. |
135 | const PERMIT_SHIFT: usize = 1; |
136 | |
137 | /// Creates a new semaphore with the initial number of permits |
138 | /// |
139 | /// Maximum number of permits on 32-bit platforms is `1<<29`. |
140 | pub(crate) fn new(permits: usize) -> Self { |
141 | assert!( |
142 | permits <= Self::MAX_PERMITS, |
143 | "a semaphore may not have more than MAX_PERMITS permits ( {})" , |
144 | Self::MAX_PERMITS |
145 | ); |
146 | |
147 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
148 | let resource_span = { |
149 | let resource_span = tracing::trace_span!( |
150 | parent: None, |
151 | "runtime.resource" , |
152 | concrete_type = "Semaphore" , |
153 | kind = "Sync" , |
154 | is_internal = true |
155 | ); |
156 | |
157 | resource_span.in_scope(|| { |
158 | tracing::trace!( |
159 | target: "runtime::resource::state_update" , |
160 | permits = permits, |
161 | permits.op = "override" , |
162 | ) |
163 | }); |
164 | resource_span |
165 | }; |
166 | |
167 | Self { |
168 | permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), |
169 | waiters: Mutex::new(Waitlist { |
170 | queue: LinkedList::new(), |
171 | closed: false, |
172 | }), |
173 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
174 | resource_span, |
175 | } |
176 | } |
177 | |
178 | /// Creates a new semaphore with the initial number of permits. |
179 | /// |
180 | /// Maximum number of permits on 32-bit platforms is `1<<29`. |
181 | #[cfg (not(all(loom, test)))] |
182 | pub(crate) const fn const_new(permits: usize) -> Self { |
183 | assert!(permits <= Self::MAX_PERMITS); |
184 | |
185 | Self { |
186 | permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), |
187 | waiters: Mutex::const_new(Waitlist { |
188 | queue: LinkedList::new(), |
189 | closed: false, |
190 | }), |
191 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
192 | resource_span: tracing::Span::none(), |
193 | } |
194 | } |
195 | |
196 | /// Creates a new closed semaphore with 0 permits. |
197 | pub(crate) fn new_closed() -> Self { |
198 | Self { |
199 | permits: AtomicUsize::new(Self::CLOSED), |
200 | waiters: Mutex::new(Waitlist { |
201 | queue: LinkedList::new(), |
202 | closed: true, |
203 | }), |
204 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
205 | resource_span: tracing::Span::none(), |
206 | } |
207 | } |
208 | |
209 | /// Creates a new closed semaphore with 0 permits. |
210 | #[cfg (not(all(loom, test)))] |
211 | pub(crate) const fn const_new_closed() -> Self { |
212 | Self { |
213 | permits: AtomicUsize::new(Self::CLOSED), |
214 | waiters: Mutex::const_new(Waitlist { |
215 | queue: LinkedList::new(), |
216 | closed: true, |
217 | }), |
218 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
219 | resource_span: tracing::Span::none(), |
220 | } |
221 | } |
222 | |
223 | /// Returns the current number of available permits. |
224 | pub(crate) fn available_permits(&self) -> usize { |
225 | self.permits.load(Acquire) >> Self::PERMIT_SHIFT |
226 | } |
227 | |
228 | /// Adds `added` new permits to the semaphore. |
229 | /// |
230 | /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. |
231 | pub(crate) fn release(&self, added: usize) { |
232 | if added == 0 { |
233 | return; |
234 | } |
235 | |
236 | // Assign permits to the wait queue |
237 | self.add_permits_locked(added, self.waiters.lock()); |
238 | } |
239 | |
240 | /// Closes the semaphore. This prevents the semaphore from issuing new |
241 | /// permits and notifies all pending waiters. |
242 | pub(crate) fn close(&self) { |
243 | let mut waiters = self.waiters.lock(); |
244 | // If the semaphore's permits counter has enough permits for an |
245 | // unqueued waiter to acquire all the permits it needs immediately, |
246 | // it won't touch the wait list. Therefore, we have to set a bit on |
247 | // the permit counter as well. However, we must do this while |
248 | // holding the lock --- otherwise, if we set the bit and then wait |
249 | // to acquire the lock we'll enter an inconsistent state where the |
250 | // permit counter is closed, but the wait list is not. |
251 | self.permits.fetch_or(Self::CLOSED, Release); |
252 | waiters.closed = true; |
253 | while let Some(mut waiter) = waiters.queue.pop_back() { |
254 | let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; |
255 | if let Some(waker) = waker { |
256 | waker.wake(); |
257 | } |
258 | } |
259 | } |
260 | |
261 | /// Returns true if the semaphore is closed. |
262 | pub(crate) fn is_closed(&self) -> bool { |
263 | self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED |
264 | } |
265 | |
266 | pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> { |
267 | assert!( |
268 | num_permits <= Self::MAX_PERMITS, |
269 | "a semaphore may not have more than MAX_PERMITS permits ( {})" , |
270 | Self::MAX_PERMITS |
271 | ); |
272 | let num_permits = num_permits << Self::PERMIT_SHIFT; |
273 | let mut curr = self.permits.load(Acquire); |
274 | loop { |
275 | // Has the semaphore closed? |
276 | if curr & Self::CLOSED == Self::CLOSED { |
277 | return Err(TryAcquireError::Closed); |
278 | } |
279 | |
280 | // Are there enough permits remaining? |
281 | if curr < num_permits { |
282 | return Err(TryAcquireError::NoPermits); |
283 | } |
284 | |
285 | let next = curr - num_permits; |
286 | |
287 | match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { |
288 | Ok(_) => { |
289 | // TODO: Instrument once issue has been solved |
290 | return Ok(()); |
291 | } |
292 | Err(actual) => curr = actual, |
293 | } |
294 | } |
295 | } |
296 | |
297 | pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> { |
298 | Acquire::new(self, num_permits) |
299 | } |
300 | |
301 | /// Release `rem` permits to the semaphore's wait list, starting from the |
302 | /// end of the queue. |
303 | /// |
304 | /// If `rem` exceeds the number of permits needed by the wait list, the |
305 | /// remainder are assigned back to the semaphore. |
306 | fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { |
307 | let mut wakers = WakeList::new(); |
308 | let mut lock = Some(waiters); |
309 | let mut is_empty = false; |
310 | while rem > 0 { |
311 | let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); |
312 | 'inner: while wakers.can_push() { |
313 | // Was the waiter assigned enough permits to wake it? |
314 | match waiters.queue.last() { |
315 | Some(waiter) => { |
316 | if !waiter.assign_permits(&mut rem) { |
317 | break 'inner; |
318 | } |
319 | } |
320 | None => { |
321 | is_empty = true; |
322 | // If we assigned permits to all the waiters in the queue, and there are |
323 | // still permits left over, assign them back to the semaphore. |
324 | break 'inner; |
325 | } |
326 | }; |
327 | let mut waiter = waiters.queue.pop_back().unwrap(); |
328 | if let Some(waker) = |
329 | unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } |
330 | { |
331 | wakers.push(waker); |
332 | } |
333 | } |
334 | |
335 | if rem > 0 && is_empty { |
336 | let permits = rem; |
337 | assert!( |
338 | permits <= Self::MAX_PERMITS, |
339 | "cannot add more than MAX_PERMITS permits ( {})" , |
340 | Self::MAX_PERMITS |
341 | ); |
342 | let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); |
343 | let prev = prev >> Self::PERMIT_SHIFT; |
344 | assert!( |
345 | prev + permits <= Self::MAX_PERMITS, |
346 | "number of added permits ( {}) would overflow MAX_PERMITS ( {})" , |
347 | rem, |
348 | Self::MAX_PERMITS |
349 | ); |
350 | |
351 | // add remaining permits back |
352 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
353 | self.resource_span.in_scope(|| { |
354 | tracing::trace!( |
355 | target: "runtime::resource::state_update" , |
356 | permits = rem, |
357 | permits.op = "add" , |
358 | ) |
359 | }); |
360 | |
361 | rem = 0; |
362 | } |
363 | |
364 | drop(waiters); // release the lock |
365 | |
366 | wakers.wake_all(); |
367 | } |
368 | |
369 | assert_eq!(rem, 0); |
370 | } |
371 | |
372 | /// Decrease a semaphore's permits by a maximum of `n`. |
373 | /// |
374 | /// If there are insufficient permits and it's not possible to reduce by `n`, |
375 | /// return the number of permits that were actually reduced. |
376 | pub(crate) fn forget_permits(&self, n: usize) -> usize { |
377 | if n == 0 { |
378 | return 0; |
379 | } |
380 | |
381 | let mut curr_bits = self.permits.load(Acquire); |
382 | loop { |
383 | let curr = curr_bits >> Self::PERMIT_SHIFT; |
384 | let new = curr.saturating_sub(n); |
385 | match self.permits.compare_exchange_weak( |
386 | curr_bits, |
387 | new << Self::PERMIT_SHIFT, |
388 | AcqRel, |
389 | Acquire, |
390 | ) { |
391 | Ok(_) => return std::cmp::min(curr, n), |
392 | Err(actual) => curr_bits = actual, |
393 | }; |
394 | } |
395 | } |
396 | |
397 | fn poll_acquire( |
398 | &self, |
399 | cx: &mut Context<'_>, |
400 | num_permits: usize, |
401 | node: Pin<&mut Waiter>, |
402 | queued: bool, |
403 | ) -> Poll<Result<(), AcquireError>> { |
404 | let mut acquired = 0; |
405 | |
406 | let needed = if queued { |
407 | node.state.load(Acquire) << Self::PERMIT_SHIFT |
408 | } else { |
409 | num_permits << Self::PERMIT_SHIFT |
410 | }; |
411 | |
412 | let mut lock = None; |
413 | // First, try to take the requested number of permits from the |
414 | // semaphore. |
415 | let mut curr = self.permits.load(Acquire); |
416 | let mut waiters = loop { |
417 | // Has the semaphore closed? |
418 | if curr & Self::CLOSED > 0 { |
419 | return Poll::Ready(Err(AcquireError::closed())); |
420 | } |
421 | |
422 | let mut remaining = 0; |
423 | let total = curr |
424 | .checked_add(acquired) |
425 | .expect("number of permits must not overflow" ); |
426 | let (next, acq) = if total >= needed { |
427 | let next = curr - (needed - acquired); |
428 | (next, needed >> Self::PERMIT_SHIFT) |
429 | } else { |
430 | remaining = (needed - acquired) - curr; |
431 | (0, curr >> Self::PERMIT_SHIFT) |
432 | }; |
433 | |
434 | if remaining > 0 && lock.is_none() { |
435 | // No permits were immediately available, so this permit will |
436 | // (probably) need to wait. We'll need to acquire a lock on the |
437 | // wait queue before continuing. We need to do this _before_ the |
438 | // CAS that sets the new value of the semaphore's `permits` |
439 | // counter. Otherwise, if we subtract the permits and then |
440 | // acquire the lock, we might miss additional permits being |
441 | // added while waiting for the lock. |
442 | lock = Some(self.waiters.lock()); |
443 | } |
444 | |
445 | match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { |
446 | Ok(_) => { |
447 | acquired += acq; |
448 | if remaining == 0 { |
449 | if !queued { |
450 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
451 | self.resource_span.in_scope(|| { |
452 | tracing::trace!( |
453 | target: "runtime::resource::state_update" , |
454 | permits = acquired, |
455 | permits.op = "sub" , |
456 | ); |
457 | tracing::trace!( |
458 | target: "runtime::resource::async_op::state_update" , |
459 | permits_obtained = acquired, |
460 | permits.op = "add" , |
461 | ) |
462 | }); |
463 | |
464 | return Poll::Ready(Ok(())); |
465 | } else if lock.is_none() { |
466 | break self.waiters.lock(); |
467 | } |
468 | } |
469 | break lock.expect("lock must be acquired before waiting" ); |
470 | } |
471 | Err(actual) => curr = actual, |
472 | } |
473 | }; |
474 | |
475 | if waiters.closed { |
476 | return Poll::Ready(Err(AcquireError::closed())); |
477 | } |
478 | |
479 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
480 | self.resource_span.in_scope(|| { |
481 | tracing::trace!( |
482 | target: "runtime::resource::state_update" , |
483 | permits = acquired, |
484 | permits.op = "sub" , |
485 | ) |
486 | }); |
487 | |
488 | if node.assign_permits(&mut acquired) { |
489 | self.add_permits_locked(acquired, waiters); |
490 | return Poll::Ready(Ok(())); |
491 | } |
492 | |
493 | assert_eq!(acquired, 0); |
494 | let mut old_waker = None; |
495 | |
496 | // Otherwise, register the waker & enqueue the node. |
497 | node.waker.with_mut(|waker| { |
498 | // Safety: the wait list is locked, so we may modify the waker. |
499 | let waker = unsafe { &mut *waker }; |
500 | // Do we need to register the new waker? |
501 | if waker |
502 | .as_ref() |
503 | .map_or(true, |waker| !waker.will_wake(cx.waker())) |
504 | { |
505 | old_waker = std::mem::replace(waker, Some(cx.waker().clone())); |
506 | } |
507 | }); |
508 | |
509 | // If the waiter is not already in the wait queue, enqueue it. |
510 | if !queued { |
511 | let node = unsafe { |
512 | let node = Pin::into_inner_unchecked(node) as *mut _; |
513 | NonNull::new_unchecked(node) |
514 | }; |
515 | |
516 | waiters.queue.push_front(node); |
517 | } |
518 | drop(waiters); |
519 | drop(old_waker); |
520 | |
521 | Poll::Pending |
522 | } |
523 | } |
524 | |
525 | impl fmt::Debug for Semaphore { |
526 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
527 | fmt&mut DebugStruct<'_, '_>.debug_struct("Semaphore" ) |
528 | .field(name:"permits" , &self.available_permits()) |
529 | .finish() |
530 | } |
531 | } |
532 | |
533 | impl Waiter { |
534 | fn new( |
535 | num_permits: usize, |
536 | #[cfg (all(tokio_unstable, feature = "tracing" ))] ctx: trace::AsyncOpTracingCtx, |
537 | ) -> Self { |
538 | Waiter { |
539 | waker: UnsafeCell::new(None), |
540 | state: AtomicUsize::new(num_permits), |
541 | pointers: linked_list::Pointers::new(), |
542 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
543 | ctx, |
544 | _p: PhantomPinned, |
545 | } |
546 | } |
547 | |
548 | /// Assign permits to the waiter. |
549 | /// |
550 | /// Returns `true` if the waiter should be removed from the queue |
551 | fn assign_permits(&self, n: &mut usize) -> bool { |
552 | let mut curr = self.state.load(Acquire); |
553 | loop { |
554 | let assign = cmp::min(curr, *n); |
555 | let next = curr - assign; |
556 | match self.state.compare_exchange(curr, next, AcqRel, Acquire) { |
557 | Ok(_) => { |
558 | *n -= assign; |
559 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
560 | self.ctx.async_op_span.in_scope(|| { |
561 | tracing::trace!( |
562 | target: "runtime::resource::async_op::state_update" , |
563 | permits_obtained = assign, |
564 | permits.op = "add" , |
565 | ); |
566 | }); |
567 | return next == 0; |
568 | } |
569 | Err(actual) => curr = actual, |
570 | } |
571 | } |
572 | } |
573 | } |
574 | |
/// Polling drives `Semaphore::poll_acquire`, tracking in `queued` whether the
/// node has been linked into the wait list.
impl Future for Acquire<'_> {
    type Output = Result<(), AcquireError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Task-tracing hook; `ready!` returns early if it yields `Pending`.
        ready!(crate::trace::trace_leaf(cx));

        // Keep the resource / async-op spans entered for the whole poll.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _resource_span = self.node.ctx.resource_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _async_op_span = self.node.ctx.async_op_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered();

        // `needed` is the originally requested `num_permits`; when `queued`
        // is true, `poll_acquire` uses the node's remaining count instead.
        let (node, semaphore, needed, queued) = self.project();

        // First, ensure the current task has enough budget to proceed.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let coop = ready!(trace_poll_op!(
            "poll_acquire",
            crate::task::coop::poll_proceed(cx),
        ));

        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        let coop = ready!(crate::task::coop::poll_proceed(cx));

        let result = match semaphore.poll_acquire(cx, needed, node, *queued) {
            Poll::Pending => {
                // `poll_acquire` linked the node (if it was not already), so
                // `Drop` must now unlink it.
                *queued = true;
                Poll::Pending
            }
            Poll::Ready(r) => {
                coop.made_progress();
                // On error, `?` returns early and leaves `queued` unchanged.
                r?;
                *queued = false;
                Poll::Ready(Ok(()))
            }
        };

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        return trace_poll_op!("poll_acquire", result);

        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        return result;
    }
}
620 | |
621 | impl<'a> Acquire<'a> { |
622 | fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self { |
623 | #[cfg (any(not(tokio_unstable), not(feature = "tracing" )))] |
624 | return Self { |
625 | node: Waiter::new(num_permits), |
626 | semaphore, |
627 | num_permits, |
628 | queued: false, |
629 | }; |
630 | |
631 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
632 | return semaphore.resource_span.in_scope(|| { |
633 | let async_op_span = |
634 | tracing::trace_span!("runtime.resource.async_op" , source = "Acquire::new" ); |
635 | let async_op_poll_span = async_op_span.in_scope(|| { |
636 | tracing::trace!( |
637 | target: "runtime::resource::async_op::state_update" , |
638 | permits_requested = num_permits, |
639 | permits.op = "override" , |
640 | ); |
641 | |
642 | tracing::trace!( |
643 | target: "runtime::resource::async_op::state_update" , |
644 | permits_obtained = 0usize, |
645 | permits.op = "override" , |
646 | ); |
647 | |
648 | tracing::trace_span!("runtime.resource.async_op.poll" ) |
649 | }); |
650 | |
651 | let ctx = trace::AsyncOpTracingCtx { |
652 | async_op_span, |
653 | async_op_poll_span, |
654 | resource_span: semaphore.resource_span.clone(), |
655 | }; |
656 | |
657 | Self { |
658 | node: Waiter::new(num_permits, ctx), |
659 | semaphore, |
660 | num_permits, |
661 | queued: false, |
662 | } |
663 | }); |
664 | } |
665 | |
666 | fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) { |
667 | fn is_unpin<T: Unpin>() {} |
668 | unsafe { |
669 | // Safety: all fields other than `node` are `Unpin` |
670 | |
671 | is_unpin::<&Semaphore>(); |
672 | is_unpin::<&mut bool>(); |
673 | is_unpin::<usize>(); |
674 | |
675 | let this = self.get_unchecked_mut(); |
676 | ( |
677 | Pin::new_unchecked(&mut this.node), |
678 | this.semaphore, |
679 | this.num_permits, |
680 | &mut this.queued, |
681 | ) |
682 | } |
683 | } |
684 | } |
685 | |
686 | impl Drop for Acquire<'_> { |
687 | fn drop(&mut self) { |
688 | // If the future is completed, there is no node in the wait list, so we |
689 | // can skip acquiring the lock. |
690 | if !self.queued { |
691 | return; |
692 | } |
693 | |
694 | // This is where we ensure safety. The future is being dropped, |
695 | // which means we must ensure that the waiter entry is no longer stored |
696 | // in the linked list. |
697 | let mut waiters: MutexGuard<'_, Waitlist> = self.semaphore.waiters.lock(); |
698 | |
699 | // remove the entry from the list |
700 | let node: NonNull = NonNull::from(&mut self.node); |
701 | // Safety: we have locked the wait list. |
702 | unsafe { waiters.queue.remove(node) }; |
703 | |
704 | let acquired_permits: usize = self.num_permits - self.node.state.load(order:Acquire); |
705 | if acquired_permits > 0 { |
706 | self.semaphore.add_permits_locked(rem:acquired_permits, waiters); |
707 | } |
708 | } |
709 | } |
710 | |
// SAFETY: the `Acquire` future is not `Sync` automatically because it contains
// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, through the
// future that `UnsafeCell` is only accessed when it is borrowed mutably (either
// in `poll` or in `drop`); the semaphore itself touches it only via the wait
// list while holding the wait-list lock. Therefore, it is safe (although not
// particularly _useful_) for the future to be borrowed immutably across threads.
unsafe impl Sync for Acquire<'_> {}
717 | |
// ===== impl AcquireError =====
719 | |
720 | impl AcquireError { |
721 | fn closed() -> AcquireError { |
722 | AcquireError(()) |
723 | } |
724 | } |
725 | |
726 | impl fmt::Display for AcquireError { |
727 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
728 | write!(fmt, "semaphore closed" ) |
729 | } |
730 | } |
731 | |
732 | impl std::error::Error for AcquireError {} |
733 | |
734 | // ===== impl TryAcquireError ===== |
735 | |
736 | impl TryAcquireError { |
737 | /// Returns `true` if the error was caused by a closed semaphore. |
738 | #[allow (dead_code)] // may be used later! |
739 | pub(crate) fn is_closed(&self) -> bool { |
740 | matches!(self, TryAcquireError::Closed) |
741 | } |
742 | |
743 | /// Returns `true` if the error was caused by calling `try_acquire` on a |
744 | /// semaphore with no available permits. |
745 | #[allow (dead_code)] // may be used later! |
746 | pub(crate) fn is_no_permits(&self) -> bool { |
747 | matches!(self, TryAcquireError::NoPermits) |
748 | } |
749 | } |
750 | |
751 | impl fmt::Display for TryAcquireError { |
752 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
753 | match self { |
754 | TryAcquireError::Closed => write!(fmt, "semaphore closed" ), |
755 | TryAcquireError::NoPermits => write!(fmt, "no permits available" ), |
756 | } |
757 | } |
758 | } |
759 | |
760 | impl std::error::Error for TryAcquireError {} |
761 | |
762 | /// # Safety |
763 | /// |
764 | /// `Waiter` is forced to be !Unpin. |
765 | unsafe impl linked_list::Link for Waiter { |
766 | type Handle = NonNull<Waiter>; |
767 | type Target = Waiter; |
768 | |
769 | fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> { |
770 | *handle |
771 | } |
772 | |
773 | unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { |
774 | ptr |
775 | } |
776 | |
777 | unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { |
778 | Waiter::addr_of_pointers(me:target) |
779 | } |
780 | } |
781 | |