| 1 | #![cfg_attr (not(feature = "sync" ), allow(unreachable_pub, dead_code))] |
| 2 | //! # Implementation Details. |
| 3 | //! |
| 4 | //! The semaphore is implemented using an intrusive linked list of waiters. An |
| 5 | //! atomic counter tracks the number of available permits. If the semaphore does |
| 6 | //! not contain the required number of permits, the task attempting to acquire |
| 7 | //! permits places its waker at the end of a queue. When new permits are made |
| 8 | //! available (such as by releasing an initial acquisition), they are assigned |
| 9 | //! to the task at the front of the queue, waking that task if its requested |
| 10 | //! number of permits is met. |
| 11 | //! |
| 12 | //! Because waiters are enqueued at the back of the linked list and dequeued |
| 13 | //! from the front, the semaphore is fair. Tasks trying to acquire large numbers |
| 14 | //! of permits at a time will always be woken eventually, even if many other |
| 15 | //! tasks are acquiring smaller numbers of permits. This means that in a |
| 16 | //! use-case like tokio's read-write lock, writers will not be starved by |
| 17 | //! readers. |
| 18 | use crate::loom::cell::UnsafeCell; |
| 19 | use crate::loom::sync::atomic::AtomicUsize; |
| 20 | use crate::loom::sync::{Mutex, MutexGuard}; |
| 21 | use crate::util::linked_list::{self, LinkedList}; |
| 22 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 23 | use crate::util::trace; |
| 24 | use crate::util::WakeList; |
| 25 | |
| 26 | use std::future::Future; |
| 27 | use std::marker::PhantomPinned; |
| 28 | use std::pin::Pin; |
| 29 | use std::ptr::NonNull; |
| 30 | use std::sync::atomic::Ordering::*; |
| 31 | use std::task::{ready, Context, Poll, Waker}; |
| 32 | use std::{cmp, fmt}; |
| 33 | |
| 34 | /// An asynchronous counting semaphore which permits waiting on multiple permits at once. |
| 35 | pub(crate) struct Semaphore { |
| 36 | waiters: Mutex<Waitlist>, |
| 37 | /// The current number of available permits in the semaphore. |
| 38 | permits: AtomicUsize, |
| 39 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 40 | resource_span: tracing::Span, |
| 41 | } |
| 42 | |
| 43 | struct Waitlist { |
| 44 | queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, |
| 45 | closed: bool, |
| 46 | } |
| 47 | |
/// Error produced by the [`Semaphore::try_acquire`] function when permits
/// could not be taken without waiting.
///
/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// No permits can be issued because the semaphore has been [closed].
    ///
    /// [closed]: crate::sync::Semaphore::close
    Closed,

    /// The semaphore currently has too few available permits.
    NoPermits,
}
/// Error produced by the [`Semaphore::acquire`] function.
///
/// Acquisition fails for exactly one reason: the semaphore has been
/// [closed].
///
/// [closed]: crate::sync::Semaphore::close
/// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire
#[derive(Debug)]
pub struct AcquireError(());
| 70 | |
| 71 | pub(crate) struct Acquire<'a> { |
| 72 | node: Waiter, |
| 73 | semaphore: &'a Semaphore, |
| 74 | num_permits: usize, |
| 75 | queued: bool, |
| 76 | } |
| 77 | |
| 78 | /// An entry in the wait queue. |
| 79 | struct Waiter { |
| 80 | /// The current state of the waiter. |
| 81 | /// |
| 82 | /// This is either the number of remaining permits required by |
| 83 | /// the waiter, or a flag indicating that the waiter is not yet queued. |
| 84 | state: AtomicUsize, |
| 85 | |
| 86 | /// The waker to notify the task awaiting permits. |
| 87 | /// |
| 88 | /// # Safety |
| 89 | /// |
| 90 | /// This may only be accessed while the wait queue is locked. |
| 91 | waker: UnsafeCell<Option<Waker>>, |
| 92 | |
| 93 | /// Intrusive linked-list pointers. |
| 94 | /// |
| 95 | /// # Safety |
| 96 | /// |
| 97 | /// This may only be accessed while the wait queue is locked. |
| 98 | /// |
| 99 | /// TODO: Ideally, we would be able to use loom to enforce that |
| 100 | /// this isn't accessed concurrently. However, it is difficult to |
| 101 | /// use a `UnsafeCell` here, since the `Link` trait requires _returning_ |
| 102 | /// references to `Pointers`, and `UnsafeCell` requires that checked access |
| 103 | /// take place inside a closure. We should consider changing `Pointers` to |
| 104 | /// use `UnsafeCell` internally. |
| 105 | pointers: linked_list::Pointers<Waiter>, |
| 106 | |
| 107 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 108 | ctx: trace::AsyncOpTracingCtx, |
| 109 | |
| 110 | /// Should not be `Unpin`. |
| 111 | _p: PhantomPinned, |
| 112 | } |
| 113 | |
| 114 | generate_addr_of_methods! { |
| 115 | impl<> Waiter { |
| 116 | unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> { |
| 117 | &self.pointers |
| 118 | } |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | impl Semaphore { |
| 123 | /// The maximum number of permits which a semaphore can hold. |
| 124 | /// |
| 125 | /// Note that this reserves three bits of flags in the permit counter, but |
| 126 | /// we only actually use one of them. However, the previous semaphore |
| 127 | /// implementation used three bits, so we will continue to reserve them to |
| 128 | /// avoid a breaking change if additional flags need to be added in the |
| 129 | /// future. |
| 130 | pub(crate) const MAX_PERMITS: usize = usize::MAX >> 3; |
| 131 | const CLOSED: usize = 1; |
| 132 | // The least-significant bit in the number of permits is reserved to use |
| 133 | // as a flag indicating that the semaphore has been closed. Consequently |
| 134 | // PERMIT_SHIFT is used to leave that bit for that purpose. |
| 135 | const PERMIT_SHIFT: usize = 1; |
| 136 | |
| 137 | /// Creates a new semaphore with the initial number of permits |
| 138 | /// |
| 139 | /// Maximum number of permits on 32-bit platforms is `1<<29`. |
| 140 | pub(crate) fn new(permits: usize) -> Self { |
| 141 | assert!( |
| 142 | permits <= Self::MAX_PERMITS, |
| 143 | "a semaphore may not have more than MAX_PERMITS permits ( {})" , |
| 144 | Self::MAX_PERMITS |
| 145 | ); |
| 146 | |
| 147 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 148 | let resource_span = { |
| 149 | let resource_span = tracing::trace_span!( |
| 150 | parent: None, |
| 151 | "runtime.resource" , |
| 152 | concrete_type = "Semaphore" , |
| 153 | kind = "Sync" , |
| 154 | is_internal = true |
| 155 | ); |
| 156 | |
| 157 | resource_span.in_scope(|| { |
| 158 | tracing::trace!( |
| 159 | target: "runtime::resource::state_update" , |
| 160 | permits = permits, |
| 161 | permits.op = "override" , |
| 162 | ) |
| 163 | }); |
| 164 | resource_span |
| 165 | }; |
| 166 | |
| 167 | Self { |
| 168 | permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), |
| 169 | waiters: Mutex::new(Waitlist { |
| 170 | queue: LinkedList::new(), |
| 171 | closed: false, |
| 172 | }), |
| 173 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 174 | resource_span, |
| 175 | } |
| 176 | } |
| 177 | |
| 178 | /// Creates a new semaphore with the initial number of permits. |
| 179 | /// |
| 180 | /// Maximum number of permits on 32-bit platforms is `1<<29`. |
| 181 | #[cfg (not(all(loom, test)))] |
| 182 | pub(crate) const fn const_new(permits: usize) -> Self { |
| 183 | assert!(permits <= Self::MAX_PERMITS); |
| 184 | |
| 185 | Self { |
| 186 | permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), |
| 187 | waiters: Mutex::const_new(Waitlist { |
| 188 | queue: LinkedList::new(), |
| 189 | closed: false, |
| 190 | }), |
| 191 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 192 | resource_span: tracing::Span::none(), |
| 193 | } |
| 194 | } |
| 195 | |
| 196 | /// Creates a new closed semaphore with 0 permits. |
| 197 | pub(crate) fn new_closed() -> Self { |
| 198 | Self { |
| 199 | permits: AtomicUsize::new(Self::CLOSED), |
| 200 | waiters: Mutex::new(Waitlist { |
| 201 | queue: LinkedList::new(), |
| 202 | closed: true, |
| 203 | }), |
| 204 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 205 | resource_span: tracing::Span::none(), |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | /// Creates a new closed semaphore with 0 permits. |
| 210 | #[cfg (not(all(loom, test)))] |
| 211 | pub(crate) const fn const_new_closed() -> Self { |
| 212 | Self { |
| 213 | permits: AtomicUsize::new(Self::CLOSED), |
| 214 | waiters: Mutex::const_new(Waitlist { |
| 215 | queue: LinkedList::new(), |
| 216 | closed: true, |
| 217 | }), |
| 218 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 219 | resource_span: tracing::Span::none(), |
| 220 | } |
| 221 | } |
| 222 | |
| 223 | /// Returns the current number of available permits. |
| 224 | pub(crate) fn available_permits(&self) -> usize { |
| 225 | self.permits.load(Acquire) >> Self::PERMIT_SHIFT |
| 226 | } |
| 227 | |
| 228 | /// Adds `added` new permits to the semaphore. |
| 229 | /// |
| 230 | /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. |
| 231 | pub(crate) fn release(&self, added: usize) { |
| 232 | if added == 0 { |
| 233 | return; |
| 234 | } |
| 235 | |
| 236 | // Assign permits to the wait queue |
| 237 | self.add_permits_locked(added, self.waiters.lock()); |
| 238 | } |
| 239 | |
| 240 | /// Closes the semaphore. This prevents the semaphore from issuing new |
| 241 | /// permits and notifies all pending waiters. |
| 242 | pub(crate) fn close(&self) { |
| 243 | let mut waiters = self.waiters.lock(); |
| 244 | // If the semaphore's permits counter has enough permits for an |
| 245 | // unqueued waiter to acquire all the permits it needs immediately, |
| 246 | // it won't touch the wait list. Therefore, we have to set a bit on |
| 247 | // the permit counter as well. However, we must do this while |
| 248 | // holding the lock --- otherwise, if we set the bit and then wait |
| 249 | // to acquire the lock we'll enter an inconsistent state where the |
| 250 | // permit counter is closed, but the wait list is not. |
| 251 | self.permits.fetch_or(Self::CLOSED, Release); |
| 252 | waiters.closed = true; |
| 253 | while let Some(mut waiter) = waiters.queue.pop_back() { |
| 254 | let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; |
| 255 | if let Some(waker) = waker { |
| 256 | waker.wake(); |
| 257 | } |
| 258 | } |
| 259 | } |
| 260 | |
| 261 | /// Returns true if the semaphore is closed. |
| 262 | pub(crate) fn is_closed(&self) -> bool { |
| 263 | self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED |
| 264 | } |
| 265 | |
| 266 | pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> { |
| 267 | assert!( |
| 268 | num_permits <= Self::MAX_PERMITS, |
| 269 | "a semaphore may not have more than MAX_PERMITS permits ( {})" , |
| 270 | Self::MAX_PERMITS |
| 271 | ); |
| 272 | let num_permits = num_permits << Self::PERMIT_SHIFT; |
| 273 | let mut curr = self.permits.load(Acquire); |
| 274 | loop { |
| 275 | // Has the semaphore closed? |
| 276 | if curr & Self::CLOSED == Self::CLOSED { |
| 277 | return Err(TryAcquireError::Closed); |
| 278 | } |
| 279 | |
| 280 | // Are there enough permits remaining? |
| 281 | if curr < num_permits { |
| 282 | return Err(TryAcquireError::NoPermits); |
| 283 | } |
| 284 | |
| 285 | let next = curr - num_permits; |
| 286 | |
| 287 | match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { |
| 288 | Ok(_) => { |
| 289 | // TODO: Instrument once issue has been solved |
| 290 | return Ok(()); |
| 291 | } |
| 292 | Err(actual) => curr = actual, |
| 293 | } |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> { |
| 298 | Acquire::new(self, num_permits) |
| 299 | } |
| 300 | |
| 301 | /// Release `rem` permits to the semaphore's wait list, starting from the |
| 302 | /// end of the queue. |
| 303 | /// |
| 304 | /// If `rem` exceeds the number of permits needed by the wait list, the |
| 305 | /// remainder are assigned back to the semaphore. |
| 306 | fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { |
| 307 | let mut wakers = WakeList::new(); |
| 308 | let mut lock = Some(waiters); |
| 309 | let mut is_empty = false; |
| 310 | while rem > 0 { |
| 311 | let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); |
| 312 | 'inner: while wakers.can_push() { |
| 313 | // Was the waiter assigned enough permits to wake it? |
| 314 | match waiters.queue.last() { |
| 315 | Some(waiter) => { |
| 316 | if !waiter.assign_permits(&mut rem) { |
| 317 | break 'inner; |
| 318 | } |
| 319 | } |
| 320 | None => { |
| 321 | is_empty = true; |
| 322 | // If we assigned permits to all the waiters in the queue, and there are |
| 323 | // still permits left over, assign them back to the semaphore. |
| 324 | break 'inner; |
| 325 | } |
| 326 | }; |
| 327 | let mut waiter = waiters.queue.pop_back().unwrap(); |
| 328 | if let Some(waker) = |
| 329 | unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } |
| 330 | { |
| 331 | wakers.push(waker); |
| 332 | } |
| 333 | } |
| 334 | |
| 335 | if rem > 0 && is_empty { |
| 336 | let permits = rem; |
| 337 | assert!( |
| 338 | permits <= Self::MAX_PERMITS, |
| 339 | "cannot add more than MAX_PERMITS permits ( {})" , |
| 340 | Self::MAX_PERMITS |
| 341 | ); |
| 342 | let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); |
| 343 | let prev = prev >> Self::PERMIT_SHIFT; |
| 344 | assert!( |
| 345 | prev + permits <= Self::MAX_PERMITS, |
| 346 | "number of added permits ( {}) would overflow MAX_PERMITS ( {})" , |
| 347 | rem, |
| 348 | Self::MAX_PERMITS |
| 349 | ); |
| 350 | |
| 351 | // add remaining permits back |
| 352 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 353 | self.resource_span.in_scope(|| { |
| 354 | tracing::trace!( |
| 355 | target: "runtime::resource::state_update" , |
| 356 | permits = rem, |
| 357 | permits.op = "add" , |
| 358 | ) |
| 359 | }); |
| 360 | |
| 361 | rem = 0; |
| 362 | } |
| 363 | |
| 364 | drop(waiters); // release the lock |
| 365 | |
| 366 | wakers.wake_all(); |
| 367 | } |
| 368 | |
| 369 | assert_eq!(rem, 0); |
| 370 | } |
| 371 | |
| 372 | /// Decrease a semaphore's permits by a maximum of `n`. |
| 373 | /// |
| 374 | /// If there are insufficient permits and it's not possible to reduce by `n`, |
| 375 | /// return the number of permits that were actually reduced. |
| 376 | pub(crate) fn forget_permits(&self, n: usize) -> usize { |
| 377 | if n == 0 { |
| 378 | return 0; |
| 379 | } |
| 380 | |
| 381 | let mut curr_bits = self.permits.load(Acquire); |
| 382 | loop { |
| 383 | let curr = curr_bits >> Self::PERMIT_SHIFT; |
| 384 | let new = curr.saturating_sub(n); |
| 385 | match self.permits.compare_exchange_weak( |
| 386 | curr_bits, |
| 387 | new << Self::PERMIT_SHIFT, |
| 388 | AcqRel, |
| 389 | Acquire, |
| 390 | ) { |
| 391 | Ok(_) => return std::cmp::min(curr, n), |
| 392 | Err(actual) => curr_bits = actual, |
| 393 | }; |
| 394 | } |
| 395 | } |
| 396 | |
| 397 | fn poll_acquire( |
| 398 | &self, |
| 399 | cx: &mut Context<'_>, |
| 400 | num_permits: usize, |
| 401 | node: Pin<&mut Waiter>, |
| 402 | queued: bool, |
| 403 | ) -> Poll<Result<(), AcquireError>> { |
| 404 | let mut acquired = 0; |
| 405 | |
| 406 | let needed = if queued { |
| 407 | node.state.load(Acquire) << Self::PERMIT_SHIFT |
| 408 | } else { |
| 409 | num_permits << Self::PERMIT_SHIFT |
| 410 | }; |
| 411 | |
| 412 | let mut lock = None; |
| 413 | // First, try to take the requested number of permits from the |
| 414 | // semaphore. |
| 415 | let mut curr = self.permits.load(Acquire); |
| 416 | let mut waiters = loop { |
| 417 | // Has the semaphore closed? |
| 418 | if curr & Self::CLOSED > 0 { |
| 419 | return Poll::Ready(Err(AcquireError::closed())); |
| 420 | } |
| 421 | |
| 422 | let mut remaining = 0; |
| 423 | let total = curr |
| 424 | .checked_add(acquired) |
| 425 | .expect("number of permits must not overflow" ); |
| 426 | let (next, acq) = if total >= needed { |
| 427 | let next = curr - (needed - acquired); |
| 428 | (next, needed >> Self::PERMIT_SHIFT) |
| 429 | } else { |
| 430 | remaining = (needed - acquired) - curr; |
| 431 | (0, curr >> Self::PERMIT_SHIFT) |
| 432 | }; |
| 433 | |
| 434 | if remaining > 0 && lock.is_none() { |
| 435 | // No permits were immediately available, so this permit will |
| 436 | // (probably) need to wait. We'll need to acquire a lock on the |
| 437 | // wait queue before continuing. We need to do this _before_ the |
| 438 | // CAS that sets the new value of the semaphore's `permits` |
| 439 | // counter. Otherwise, if we subtract the permits and then |
| 440 | // acquire the lock, we might miss additional permits being |
| 441 | // added while waiting for the lock. |
| 442 | lock = Some(self.waiters.lock()); |
| 443 | } |
| 444 | |
| 445 | match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { |
| 446 | Ok(_) => { |
| 447 | acquired += acq; |
| 448 | if remaining == 0 { |
| 449 | if !queued { |
| 450 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 451 | self.resource_span.in_scope(|| { |
| 452 | tracing::trace!( |
| 453 | target: "runtime::resource::state_update" , |
| 454 | permits = acquired, |
| 455 | permits.op = "sub" , |
| 456 | ); |
| 457 | tracing::trace!( |
| 458 | target: "runtime::resource::async_op::state_update" , |
| 459 | permits_obtained = acquired, |
| 460 | permits.op = "add" , |
| 461 | ) |
| 462 | }); |
| 463 | |
| 464 | return Poll::Ready(Ok(())); |
| 465 | } else if lock.is_none() { |
| 466 | break self.waiters.lock(); |
| 467 | } |
| 468 | } |
| 469 | break lock.expect("lock must be acquired before waiting" ); |
| 470 | } |
| 471 | Err(actual) => curr = actual, |
| 472 | } |
| 473 | }; |
| 474 | |
| 475 | if waiters.closed { |
| 476 | return Poll::Ready(Err(AcquireError::closed())); |
| 477 | } |
| 478 | |
| 479 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 480 | self.resource_span.in_scope(|| { |
| 481 | tracing::trace!( |
| 482 | target: "runtime::resource::state_update" , |
| 483 | permits = acquired, |
| 484 | permits.op = "sub" , |
| 485 | ) |
| 486 | }); |
| 487 | |
| 488 | if node.assign_permits(&mut acquired) { |
| 489 | self.add_permits_locked(acquired, waiters); |
| 490 | return Poll::Ready(Ok(())); |
| 491 | } |
| 492 | |
| 493 | assert_eq!(acquired, 0); |
| 494 | let mut old_waker = None; |
| 495 | |
| 496 | // Otherwise, register the waker & enqueue the node. |
| 497 | node.waker.with_mut(|waker| { |
| 498 | // Safety: the wait list is locked, so we may modify the waker. |
| 499 | let waker = unsafe { &mut *waker }; |
| 500 | // Do we need to register the new waker? |
| 501 | if waker |
| 502 | .as_ref() |
| 503 | .map_or(true, |waker| !waker.will_wake(cx.waker())) |
| 504 | { |
| 505 | old_waker = std::mem::replace(waker, Some(cx.waker().clone())); |
| 506 | } |
| 507 | }); |
| 508 | |
| 509 | // If the waiter is not already in the wait queue, enqueue it. |
| 510 | if !queued { |
| 511 | let node = unsafe { |
| 512 | let node = Pin::into_inner_unchecked(node) as *mut _; |
| 513 | NonNull::new_unchecked(node) |
| 514 | }; |
| 515 | |
| 516 | waiters.queue.push_front(node); |
| 517 | } |
| 518 | drop(waiters); |
| 519 | drop(old_waker); |
| 520 | |
| 521 | Poll::Pending |
| 522 | } |
| 523 | } |
| 524 | |
| 525 | impl fmt::Debug for Semaphore { |
| 526 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 527 | fmt&mut DebugStruct<'_, '_>.debug_struct("Semaphore" ) |
| 528 | .field(name:"permits" , &self.available_permits()) |
| 529 | .finish() |
| 530 | } |
| 531 | } |
| 532 | |
| 533 | impl Waiter { |
| 534 | fn new( |
| 535 | num_permits: usize, |
| 536 | #[cfg (all(tokio_unstable, feature = "tracing" ))] ctx: trace::AsyncOpTracingCtx, |
| 537 | ) -> Self { |
| 538 | Waiter { |
| 539 | waker: UnsafeCell::new(None), |
| 540 | state: AtomicUsize::new(num_permits), |
| 541 | pointers: linked_list::Pointers::new(), |
| 542 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 543 | ctx, |
| 544 | _p: PhantomPinned, |
| 545 | } |
| 546 | } |
| 547 | |
| 548 | /// Assign permits to the waiter. |
| 549 | /// |
| 550 | /// Returns `true` if the waiter should be removed from the queue |
| 551 | fn assign_permits(&self, n: &mut usize) -> bool { |
| 552 | let mut curr = self.state.load(Acquire); |
| 553 | loop { |
| 554 | let assign = cmp::min(curr, *n); |
| 555 | let next = curr - assign; |
| 556 | match self.state.compare_exchange(curr, next, AcqRel, Acquire) { |
| 557 | Ok(_) => { |
| 558 | *n -= assign; |
| 559 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 560 | self.ctx.async_op_span.in_scope(|| { |
| 561 | tracing::trace!( |
| 562 | target: "runtime::resource::async_op::state_update" , |
| 563 | permits_obtained = assign, |
| 564 | permits.op = "add" , |
| 565 | ); |
| 566 | }); |
| 567 | return next == 0; |
| 568 | } |
| 569 | Err(actual) => curr = actual, |
| 570 | } |
| 571 | } |
| 572 | } |
| 573 | } |
| 574 | |
| 575 | impl Future for Acquire<'_> { |
| 576 | type Output = Result<(), AcquireError>; |
| 577 | |
| 578 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 579 | ready!(crate::trace::trace_leaf(cx)); |
| 580 | |
| 581 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 582 | let _resource_span = self.node.ctx.resource_span.clone().entered(); |
| 583 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 584 | let _async_op_span = self.node.ctx.async_op_span.clone().entered(); |
| 585 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 586 | let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered(); |
| 587 | |
| 588 | let (node, semaphore, needed, queued) = self.project(); |
| 589 | |
| 590 | // First, ensure the current task has enough budget to proceed. |
| 591 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 592 | let coop = ready!(trace_poll_op!( |
| 593 | "poll_acquire" , |
| 594 | crate::task::coop::poll_proceed(cx), |
| 595 | )); |
| 596 | |
| 597 | #[cfg (not(all(tokio_unstable, feature = "tracing" )))] |
| 598 | let coop = ready!(crate::task::coop::poll_proceed(cx)); |
| 599 | |
| 600 | let result = match semaphore.poll_acquire(cx, needed, node, *queued) { |
| 601 | Poll::Pending => { |
| 602 | *queued = true; |
| 603 | Poll::Pending |
| 604 | } |
| 605 | Poll::Ready(r) => { |
| 606 | coop.made_progress(); |
| 607 | r?; |
| 608 | *queued = false; |
| 609 | Poll::Ready(Ok(())) |
| 610 | } |
| 611 | }; |
| 612 | |
| 613 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 614 | return trace_poll_op!("poll_acquire" , result); |
| 615 | |
| 616 | #[cfg (not(all(tokio_unstable, feature = "tracing" )))] |
| 617 | return result; |
| 618 | } |
| 619 | } |
| 620 | |
| 621 | impl<'a> Acquire<'a> { |
| 622 | fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self { |
| 623 | #[cfg (any(not(tokio_unstable), not(feature = "tracing" )))] |
| 624 | return Self { |
| 625 | node: Waiter::new(num_permits), |
| 626 | semaphore, |
| 627 | num_permits, |
| 628 | queued: false, |
| 629 | }; |
| 630 | |
| 631 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 632 | return semaphore.resource_span.in_scope(|| { |
| 633 | let async_op_span = |
| 634 | tracing::trace_span!("runtime.resource.async_op" , source = "Acquire::new" ); |
| 635 | let async_op_poll_span = async_op_span.in_scope(|| { |
| 636 | tracing::trace!( |
| 637 | target: "runtime::resource::async_op::state_update" , |
| 638 | permits_requested = num_permits, |
| 639 | permits.op = "override" , |
| 640 | ); |
| 641 | |
| 642 | tracing::trace!( |
| 643 | target: "runtime::resource::async_op::state_update" , |
| 644 | permits_obtained = 0usize, |
| 645 | permits.op = "override" , |
| 646 | ); |
| 647 | |
| 648 | tracing::trace_span!("runtime.resource.async_op.poll" ) |
| 649 | }); |
| 650 | |
| 651 | let ctx = trace::AsyncOpTracingCtx { |
| 652 | async_op_span, |
| 653 | async_op_poll_span, |
| 654 | resource_span: semaphore.resource_span.clone(), |
| 655 | }; |
| 656 | |
| 657 | Self { |
| 658 | node: Waiter::new(num_permits, ctx), |
| 659 | semaphore, |
| 660 | num_permits, |
| 661 | queued: false, |
| 662 | } |
| 663 | }); |
| 664 | } |
| 665 | |
| 666 | fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) { |
| 667 | fn is_unpin<T: Unpin>() {} |
| 668 | unsafe { |
| 669 | // Safety: all fields other than `node` are `Unpin` |
| 670 | |
| 671 | is_unpin::<&Semaphore>(); |
| 672 | is_unpin::<&mut bool>(); |
| 673 | is_unpin::<usize>(); |
| 674 | |
| 675 | let this = self.get_unchecked_mut(); |
| 676 | ( |
| 677 | Pin::new_unchecked(&mut this.node), |
| 678 | this.semaphore, |
| 679 | this.num_permits, |
| 680 | &mut this.queued, |
| 681 | ) |
| 682 | } |
| 683 | } |
| 684 | } |
| 685 | |
| 686 | impl Drop for Acquire<'_> { |
| 687 | fn drop(&mut self) { |
| 688 | // If the future is completed, there is no node in the wait list, so we |
| 689 | // can skip acquiring the lock. |
| 690 | if !self.queued { |
| 691 | return; |
| 692 | } |
| 693 | |
| 694 | // This is where we ensure safety. The future is being dropped, |
| 695 | // which means we must ensure that the waiter entry is no longer stored |
| 696 | // in the linked list. |
| 697 | let mut waiters: MutexGuard<'_, Waitlist> = self.semaphore.waiters.lock(); |
| 698 | |
| 699 | // remove the entry from the list |
| 700 | let node: NonNull = NonNull::from(&mut self.node); |
| 701 | // Safety: we have locked the wait list. |
| 702 | unsafe { waiters.queue.remove(node) }; |
| 703 | |
| 704 | let acquired_permits: usize = self.num_permits - self.node.state.load(order:Acquire); |
| 705 | if acquired_permits > 0 { |
| 706 | self.semaphore.add_permits_locked(rem:acquired_permits, waiters); |
| 707 | } |
| 708 | } |
| 709 | } |
| 710 | |
| 711 | // Safety: the `Acquire` future is not `Sync` automatically because it contains |
| 712 | // a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the |
| 713 | // `UnsafeCell` is only accessed when the future is borrowed mutably (either in |
| 714 | // `poll` or in `drop`). Therefore, it is safe (although not particularly |
| 715 | // _useful_) for the future to be borrowed immutably across threads. |
| 716 | unsafe impl Sync for Acquire<'_> {} |
| 717 | |
// ===== impl AcquireError =====
| 719 | |
| 720 | impl AcquireError { |
| 721 | fn closed() -> AcquireError { |
| 722 | AcquireError(()) |
| 723 | } |
| 724 | } |
| 725 | |
| 726 | impl fmt::Display for AcquireError { |
| 727 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 728 | write!(fmt, "semaphore closed" ) |
| 729 | } |
| 730 | } |
| 731 | |
| 732 | impl std::error::Error for AcquireError {} |
| 733 | |
| 734 | // ===== impl TryAcquireError ===== |
| 735 | |
| 736 | impl TryAcquireError { |
| 737 | /// Returns `true` if the error was caused by a closed semaphore. |
| 738 | #[allow (dead_code)] // may be used later! |
| 739 | pub(crate) fn is_closed(&self) -> bool { |
| 740 | matches!(self, TryAcquireError::Closed) |
| 741 | } |
| 742 | |
| 743 | /// Returns `true` if the error was caused by calling `try_acquire` on a |
| 744 | /// semaphore with no available permits. |
| 745 | #[allow (dead_code)] // may be used later! |
| 746 | pub(crate) fn is_no_permits(&self) -> bool { |
| 747 | matches!(self, TryAcquireError::NoPermits) |
| 748 | } |
| 749 | } |
| 750 | |
| 751 | impl fmt::Display for TryAcquireError { |
| 752 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 753 | match self { |
| 754 | TryAcquireError::Closed => write!(fmt, "semaphore closed" ), |
| 755 | TryAcquireError::NoPermits => write!(fmt, "no permits available" ), |
| 756 | } |
| 757 | } |
| 758 | } |
| 759 | |
| 760 | impl std::error::Error for TryAcquireError {} |
| 761 | |
| 762 | /// # Safety |
| 763 | /// |
| 764 | /// `Waiter` is forced to be !Unpin. |
| 765 | unsafe impl linked_list::Link for Waiter { |
| 766 | type Handle = NonNull<Waiter>; |
| 767 | type Target = Waiter; |
| 768 | |
| 769 | fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> { |
| 770 | *handle |
| 771 | } |
| 772 | |
| 773 | unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { |
| 774 | ptr |
| 775 | } |
| 776 | |
| 777 | unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { |
| 778 | Waiter::addr_of_pointers(me:target) |
| 779 | } |
| 780 | } |
| 781 | |