| 1 | //! A synchronization primitive for controlling access to a pool of resources. |
| 2 | use core::cell::{Cell, RefCell}; |
| 3 | use core::convert::Infallible; |
| 4 | use core::future::{poll_fn, Future}; |
| 5 | use core::task::{Poll, Waker}; |
| 6 | |
| 7 | use heapless::Deque; |
| 8 | |
| 9 | use crate::blocking_mutex::raw::RawMutex; |
| 10 | use crate::blocking_mutex::Mutex; |
| 11 | use crate::waitqueue::WakerRegistration; |
| 12 | |
| 13 | /// An asynchronous semaphore. |
| 14 | /// |
| 15 | /// A semaphore tracks a number of permits, typically representing a pool of shared resources. |
| 16 | /// Users can acquire permits to synchronize access to those resources. The semaphore does not |
| 17 | /// contain the resources themselves, only the count of available permits. |
| 18 | pub trait Semaphore: Sized { |
| 19 | /// The error returned when the semaphore is unable to acquire the requested permits. |
| 20 | type Error; |
| 21 | |
| 22 | /// Asynchronously acquire one or more permits from the semaphore. |
| 23 | async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>; |
| 24 | |
| 25 | /// Try to immediately acquire one or more permits from the semaphore. |
| 26 | fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>>; |
| 27 | |
| 28 | /// Asynchronously acquire all permits controlled by the semaphore. |
| 29 | /// |
| 30 | /// This method will wait until at least `min` permits are available, then acquire all available permits |
| 31 | /// from the semaphore. Note that other tasks may have already acquired some permits which could be released |
| 32 | /// back to the semaphore at any time. The number of permits actually acquired may be determined by calling |
| 33 | /// [`SemaphoreReleaser::permits`]. |
| 34 | async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>; |
| 35 | |
| 36 | /// Try to immediately acquire all available permits from the semaphore, if at least `min` permits are available. |
| 37 | fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>>; |
| 38 | |
| 39 | /// Release `permits` back to the semaphore, making them available to be acquired. |
| 40 | fn release(&self, permits: usize); |
| 41 | |
| 42 | /// Reset the number of available permints in the semaphore to `permits`. |
| 43 | fn set(&self, permits: usize); |
| 44 | } |
| 45 | |
| 46 | /// A representation of a number of acquired permits. |
| 47 | /// |
| 48 | /// The acquired permits will be released back to the [`Semaphore`] when this is dropped. |
| 49 | pub struct SemaphoreReleaser<'a, S: Semaphore> { |
| 50 | semaphore: &'a S, |
| 51 | permits: usize, |
| 52 | } |
| 53 | |
| 54 | impl<'a, S: Semaphore> Drop for SemaphoreReleaser<'a, S> { |
| 55 | fn drop(&mut self) { |
| 56 | self.semaphore.release(self.permits); |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | impl<'a, S: Semaphore> SemaphoreReleaser<'a, S> { |
| 61 | /// The number of acquired permits. |
| 62 | pub fn permits(&self) -> usize { |
| 63 | self.permits |
| 64 | } |
| 65 | |
| 66 | /// Prevent the acquired permits from being released on drop. |
| 67 | /// |
| 68 | /// Returns the number of acquired permits. |
| 69 | pub fn disarm(self) -> usize { |
| 70 | let permits: usize = self.permits; |
| 71 | core::mem::forget(self); |
| 72 | permits |
| 73 | } |
| 74 | } |
| 75 | |
| 76 | /// A greedy [`Semaphore`] implementation. |
| 77 | /// |
| 78 | /// Tasks can acquire permits as soon as they become available, even if another task |
| 79 | /// is waiting on a larger number of permits. |
| 80 | pub struct GreedySemaphore<M: RawMutex> { |
| 81 | state: Mutex<M, Cell<SemaphoreState>>, |
| 82 | } |
| 83 | |
| 84 | impl<M: RawMutex> Default for GreedySemaphore<M> { |
| 85 | fn default() -> Self { |
| 86 | Self::new(permits:0) |
| 87 | } |
| 88 | } |
| 89 | |
| 90 | impl<M: RawMutex> GreedySemaphore<M> { |
| 91 | /// Create a new `Semaphore`. |
| 92 | pub const fn new(permits: usize) -> Self { |
| 93 | Self { |
| 94 | state: Mutex::new(Cell::new(SemaphoreState { |
| 95 | permits, |
| 96 | waker: WakerRegistration::new(), |
| 97 | })), |
| 98 | } |
| 99 | } |
| 100 | |
| 101 | #[cfg (test)] |
| 102 | fn permits(&self) -> usize { |
| 103 | self.state.lock(|cell| { |
| 104 | let state = cell.replace(SemaphoreState::EMPTY); |
| 105 | let permits = state.permits; |
| 106 | cell.replace(state); |
| 107 | permits |
| 108 | }) |
| 109 | } |
| 110 | |
| 111 | fn poll_acquire( |
| 112 | &self, |
| 113 | permits: usize, |
| 114 | acquire_all: bool, |
| 115 | waker: Option<&Waker>, |
| 116 | ) -> Poll<Result<SemaphoreReleaser<'_, Self>, Infallible>> { |
| 117 | self.state.lock(|cell| { |
| 118 | let mut state = cell.replace(SemaphoreState::EMPTY); |
| 119 | if let Some(permits) = state.take(permits, acquire_all) { |
| 120 | cell.set(state); |
| 121 | Poll::Ready(Ok(SemaphoreReleaser { |
| 122 | semaphore: self, |
| 123 | permits, |
| 124 | })) |
| 125 | } else { |
| 126 | if let Some(waker) = waker { |
| 127 | state.register(waker); |
| 128 | } |
| 129 | cell.set(state); |
| 130 | Poll::Pending |
| 131 | } |
| 132 | }) |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | impl<M: RawMutex> Semaphore for GreedySemaphore<M> { |
| 137 | type Error = Infallible; |
| 138 | |
| 139 | async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { |
| 140 | poll_fn(|cx| self.poll_acquire(permits, false, Some(cx.waker()))).await |
| 141 | } |
| 142 | |
| 143 | fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| 144 | match self.poll_acquire(permits, false, None) { |
| 145 | Poll::Ready(Ok(n)) => Some(n), |
| 146 | _ => None, |
| 147 | } |
| 148 | } |
| 149 | |
| 150 | async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { |
| 151 | poll_fn(|cx| self.poll_acquire(min, true, Some(cx.waker()))).await |
| 152 | } |
| 153 | |
| 154 | fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| 155 | match self.poll_acquire(min, true, None) { |
| 156 | Poll::Ready(Ok(n)) => Some(n), |
| 157 | _ => None, |
| 158 | } |
| 159 | } |
| 160 | |
| 161 | fn release(&self, permits: usize) { |
| 162 | if permits > 0 { |
| 163 | self.state.lock(|cell| { |
| 164 | let mut state = cell.replace(SemaphoreState::EMPTY); |
| 165 | state.permits += permits; |
| 166 | state.wake(); |
| 167 | cell.set(state); |
| 168 | }); |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | fn set(&self, permits: usize) { |
| 173 | self.state.lock(|cell| { |
| 174 | let mut state = cell.replace(SemaphoreState::EMPTY); |
| 175 | if permits > state.permits { |
| 176 | state.wake(); |
| 177 | } |
| 178 | state.permits = permits; |
| 179 | cell.set(state); |
| 180 | }); |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | struct SemaphoreState { |
| 185 | permits: usize, |
| 186 | waker: WakerRegistration, |
| 187 | } |
| 188 | |
| 189 | impl SemaphoreState { |
| 190 | const EMPTY: SemaphoreState = SemaphoreState { |
| 191 | permits: 0, |
| 192 | waker: WakerRegistration::new(), |
| 193 | }; |
| 194 | |
| 195 | fn register(&mut self, w: &Waker) { |
| 196 | self.waker.register(w); |
| 197 | } |
| 198 | |
| 199 | fn take(&mut self, mut permits: usize, acquire_all: bool) -> Option<usize> { |
| 200 | if self.permits < permits { |
| 201 | None |
| 202 | } else { |
| 203 | if acquire_all { |
| 204 | permits = self.permits; |
| 205 | } |
| 206 | self.permits -= permits; |
| 207 | Some(permits) |
| 208 | } |
| 209 | } |
| 210 | |
| 211 | fn wake(&mut self) { |
| 212 | self.waker.wake(); |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | /// A fair [`Semaphore`] implementation. |
| 217 | /// |
| 218 | /// Tasks are allowed to acquire permits in FIFO order. A task waiting to acquire |
| 219 | /// a large number of permits will prevent other tasks from acquiring any permits |
| 220 | /// until its request is satisfied. |
| 221 | /// |
| 222 | /// Up to `N` tasks may attempt to acquire permits concurrently. If additional |
| 223 | /// tasks attempt to acquire a permit, a [`WaitQueueFull`] error will be returned. |
| 224 | pub struct FairSemaphore<M, const N: usize> |
| 225 | where |
| 226 | M: RawMutex, |
| 227 | { |
| 228 | state: Mutex<M, RefCell<FairSemaphoreState<N>>>, |
| 229 | } |
| 230 | |
| 231 | impl<M, const N: usize> Default for FairSemaphore<M, N> |
| 232 | where |
| 233 | M: RawMutex, |
| 234 | { |
| 235 | fn default() -> Self { |
| 236 | Self::new(permits:0) |
| 237 | } |
| 238 | } |
| 239 | |
| 240 | impl<M, const N: usize> FairSemaphore<M, N> |
| 241 | where |
| 242 | M: RawMutex, |
| 243 | { |
| 244 | /// Create a new `FairSemaphore`. |
| 245 | pub const fn new(permits: usize) -> Self { |
| 246 | Self { |
| 247 | state: Mutex::new(RefCell::new(FairSemaphoreState::new(permits))), |
| 248 | } |
| 249 | } |
| 250 | |
| 251 | #[cfg (test)] |
| 252 | fn permits(&self) -> usize { |
| 253 | self.state.lock(|cell| cell.borrow().permits) |
| 254 | } |
| 255 | |
| 256 | fn poll_acquire( |
| 257 | &self, |
| 258 | permits: usize, |
| 259 | acquire_all: bool, |
| 260 | cx: Option<(&mut Option<usize>, &Waker)>, |
| 261 | ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> { |
| 262 | let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None); |
| 263 | self.state.lock(|cell| { |
| 264 | let mut state = cell.borrow_mut(); |
| 265 | if let Some(permits) = state.take(ticket, permits, acquire_all) { |
| 266 | Poll::Ready(Ok(SemaphoreReleaser { |
| 267 | semaphore: self, |
| 268 | permits, |
| 269 | })) |
| 270 | } else if let Some((ticket_ref, waker)) = cx { |
| 271 | match state.register(ticket, waker) { |
| 272 | Ok(ticket) => { |
| 273 | *ticket_ref = Some(ticket); |
| 274 | Poll::Pending |
| 275 | } |
| 276 | Err(err) => Poll::Ready(Err(err)), |
| 277 | } |
| 278 | } else { |
| 279 | Poll::Pending |
| 280 | } |
| 281 | }) |
| 282 | } |
| 283 | } |
| 284 | |
/// An error indicating the [`FairSemaphore`]'s wait queue is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct WaitQueueFull;

impl core::fmt::Display for WaitQueueFull {
    /// Write a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("semaphore wait queue is full")
    }
}
| 289 | |
| 290 | impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { |
| 291 | type Error = WaitQueueFull; |
| 292 | |
| 293 | fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { |
| 294 | FairAcquire { |
| 295 | sema: self, |
| 296 | permits, |
| 297 | ticket: None, |
| 298 | } |
| 299 | } |
| 300 | |
| 301 | fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| 302 | match self.poll_acquire(permits, false, None) { |
| 303 | Poll::Ready(Ok(x)) => Some(x), |
| 304 | _ => None, |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { |
| 309 | FairAcquireAll { |
| 310 | sema: self, |
| 311 | min, |
| 312 | ticket: None, |
| 313 | } |
| 314 | } |
| 315 | |
| 316 | fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| 317 | match self.poll_acquire(min, true, None) { |
| 318 | Poll::Ready(Ok(x)) => Some(x), |
| 319 | _ => None, |
| 320 | } |
| 321 | } |
| 322 | |
| 323 | fn release(&self, permits: usize) { |
| 324 | if permits > 0 { |
| 325 | self.state.lock(|cell| { |
| 326 | let mut state = cell.borrow_mut(); |
| 327 | state.permits += permits; |
| 328 | state.wake(); |
| 329 | }); |
| 330 | } |
| 331 | } |
| 332 | |
| 333 | fn set(&self, permits: usize) { |
| 334 | self.state.lock(|cell| { |
| 335 | let mut state = cell.borrow_mut(); |
| 336 | if permits > state.permits { |
| 337 | state.wake(); |
| 338 | } |
| 339 | state.permits = permits; |
| 340 | }); |
| 341 | } |
| 342 | } |
| 343 | |
| 344 | struct FairAcquire<'a, M: RawMutex, const N: usize> { |
| 345 | sema: &'a FairSemaphore<M, N>, |
| 346 | permits: usize, |
| 347 | ticket: Option<usize>, |
| 348 | } |
| 349 | |
| 350 | impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> { |
| 351 | fn drop(&mut self) { |
| 352 | self.sema |
| 353 | .state |
| 354 | .lock(|cell: &RefCell>| cell.borrow_mut().cancel(self.ticket.take())); |
| 355 | } |
| 356 | } |
| 357 | |
| 358 | impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> { |
| 359 | type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; |
| 360 | |
| 361 | fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { |
| 362 | self.sema |
| 363 | .poll_acquire(self.permits, acquire_all:false, cx:Some((&mut self.ticket, cx.waker()))) |
| 364 | } |
| 365 | } |
| 366 | |
| 367 | struct FairAcquireAll<'a, M: RawMutex, const N: usize> { |
| 368 | sema: &'a FairSemaphore<M, N>, |
| 369 | min: usize, |
| 370 | ticket: Option<usize>, |
| 371 | } |
| 372 | |
| 373 | impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> { |
| 374 | fn drop(&mut self) { |
| 375 | self.sema |
| 376 | .state |
| 377 | .lock(|cell: &RefCell>| cell.borrow_mut().cancel(self.ticket.take())); |
| 378 | } |
| 379 | } |
| 380 | |
| 381 | impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> { |
| 382 | type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; |
| 383 | |
| 384 | fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { |
| 385 | self.sema |
| 386 | .poll_acquire(self.min, acquire_all:true, cx:Some((&mut self.ticket, cx.waker()))) |
| 387 | } |
| 388 | } |
| 389 | |
| 390 | struct FairSemaphoreState<const N: usize> { |
| 391 | permits: usize, |
| 392 | next_ticket: usize, |
| 393 | wakers: Deque<Option<Waker>, N>, |
| 394 | } |
| 395 | |
| 396 | impl<const N: usize> FairSemaphoreState<N> { |
| 397 | /// Create a new empty instance |
| 398 | const fn new(permits: usize) -> Self { |
| 399 | Self { |
| 400 | permits, |
| 401 | next_ticket: 0, |
| 402 | wakers: Deque::new(), |
| 403 | } |
| 404 | } |
| 405 | |
| 406 | /// Register a waker. If the queue is full the function returns an error |
| 407 | fn register(&mut self, ticket: Option<usize>, w: &Waker) -> Result<usize, WaitQueueFull> { |
| 408 | self.pop_canceled(); |
| 409 | |
| 410 | match ticket { |
| 411 | None => { |
| 412 | let ticket = self.next_ticket.wrapping_add(self.wakers.len()); |
| 413 | self.wakers.push_back(Some(w.clone())).or(Err(WaitQueueFull))?; |
| 414 | Ok(ticket) |
| 415 | } |
| 416 | Some(ticket) => { |
| 417 | self.set_waker(ticket, Some(w.clone())); |
| 418 | Ok(ticket) |
| 419 | } |
| 420 | } |
| 421 | } |
| 422 | |
| 423 | fn cancel(&mut self, ticket: Option<usize>) { |
| 424 | if let Some(ticket) = ticket { |
| 425 | self.set_waker(ticket, None); |
| 426 | } |
| 427 | } |
| 428 | |
| 429 | fn set_waker(&mut self, ticket: usize, waker: Option<Waker>) { |
| 430 | let i = ticket.wrapping_sub(self.next_ticket); |
| 431 | if i < self.wakers.len() { |
| 432 | let (a, b) = self.wakers.as_mut_slices(); |
| 433 | let x = if i < a.len() { &mut a[i] } else { &mut b[i - a.len()] }; |
| 434 | *x = waker; |
| 435 | } |
| 436 | } |
| 437 | |
| 438 | fn take(&mut self, ticket: Option<usize>, mut permits: usize, acquire_all: bool) -> Option<usize> { |
| 439 | self.pop_canceled(); |
| 440 | |
| 441 | if permits > self.permits { |
| 442 | return None; |
| 443 | } |
| 444 | |
| 445 | match ticket { |
| 446 | Some(n) if n != self.next_ticket => return None, |
| 447 | None if !self.wakers.is_empty() => return None, |
| 448 | _ => (), |
| 449 | } |
| 450 | |
| 451 | if acquire_all { |
| 452 | permits = self.permits; |
| 453 | } |
| 454 | self.permits -= permits; |
| 455 | |
| 456 | if ticket.is_some() { |
| 457 | self.pop(); |
| 458 | if self.permits > 0 { |
| 459 | self.wake(); |
| 460 | } |
| 461 | } |
| 462 | |
| 463 | Some(permits) |
| 464 | } |
| 465 | |
| 466 | fn pop_canceled(&mut self) { |
| 467 | while let Some(None) = self.wakers.front() { |
| 468 | self.pop(); |
| 469 | } |
| 470 | } |
| 471 | |
| 472 | /// Panics if `self.wakers` is empty |
| 473 | fn pop(&mut self) { |
| 474 | self.wakers.pop_front().unwrap(); |
| 475 | self.next_ticket = self.next_ticket.wrapping_add(1); |
| 476 | } |
| 477 | |
| 478 | fn wake(&mut self) { |
| 479 | self.pop_canceled(); |
| 480 | |
| 481 | if let Some(Some(waker)) = self.wakers.front() { |
| 482 | waker.wake_by_ref(); |
| 483 | } |
| 484 | } |
| 485 | } |
| 486 | |
#[cfg(test)]
mod tests {
    /// Tests for [`GreedySemaphore`](super::GreedySemaphore).
    mod greedy {
        use core::pin::pin;

        use futures_util::poll;

        use super::super::*;
        use crate::blocking_mutex::raw::NoopRawMutex;

        /// Permits are taken immediately and returned when the releaser is dropped.
        #[test]
        fn try_acquire() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let held = sem.try_acquire(1).unwrap();
            assert_eq!(held.permits(), 1);
            assert_eq!(sem.permits(), 2);

            drop(held);
            assert_eq!(sem.permits(), 3);
        }

        /// A disarmed releaser does not hand its permits back.
        #[test]
        fn disarm() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let held = sem.try_acquire(1).unwrap();
            assert_eq!(held.disarm(), 1);
            assert_eq!(sem.permits(), 2);
        }

        /// The async path behaves like `try_acquire` when permits are available.
        #[futures_test::test]
        async fn acquire() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let held = sem.acquire(1).await.unwrap();
            assert_eq!(held.permits(), 1);
            assert_eq!(sem.permits(), 2);

            drop(held);
            assert_eq!(sem.permits(), 3);
        }

        /// `try_acquire_all` drains every available permit.
        #[test]
        fn try_acquire_all() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let held = sem.try_acquire_all(1).unwrap();
            assert_eq!(held.permits(), 3);
            assert_eq!(sem.permits(), 0);
        }

        /// `acquire_all` drains every available permit.
        #[futures_test::test]
        async fn acquire_all() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let held = sem.acquire_all(1).await.unwrap();
            assert_eq!(held.permits(), 3);
            assert_eq!(sem.permits(), 0);
        }

        /// Releasing adds permits on top of the current count.
        #[test]
        fn release() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(sem.permits(), 3);
            sem.release(2);
            assert_eq!(sem.permits(), 5);
        }

        /// `set` overwrites the current count.
        #[test]
        fn set() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(sem.permits(), 3);
            sem.set(2);
            assert_eq!(sem.permits(), 2);
        }

        /// A request larger than the remaining permits fails until permits are returned.
        #[test]
        fn contested() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let first = sem.try_acquire(1).unwrap();
            let second = sem.try_acquire(3);
            assert!(second.is_none());

            drop(first);

            let second = sem.try_acquire(3);
            assert!(second.is_some());
        }

        /// Small requests may overtake a pending larger request.
        #[futures_test::test]
        async fn greedy() {
            let sem = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = sem.try_acquire(1).unwrap();

            let b_fut = sem.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            // Succeeds even though `b` is waiting
            let c = sem.try_acquire(1);
            assert!(c.is_some());

            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            drop(a);

            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            drop(c);

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());
        }
    }

    /// Tests for [`FairSemaphore`](super::FairSemaphore).
    mod fair {
        use core::pin::pin;
        use core::time::Duration;

        use futures_executor::ThreadPool;
        use futures_timer::Delay;
        use futures_util::poll;
        use futures_util::task::SpawnExt;
        use static_cell::StaticCell;

        use super::super::*;
        use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};

        /// Permits are taken immediately and returned when the releaser is dropped.
        #[test]
        fn try_acquire() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let held = sem.try_acquire(1).unwrap();
            assert_eq!(held.permits(), 1);
            assert_eq!(sem.permits(), 2);

            drop(held);
            assert_eq!(sem.permits(), 3);
        }

        /// A disarmed releaser does not hand its permits back.
        #[test]
        fn disarm() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let held = sem.try_acquire(1).unwrap();
            assert_eq!(held.disarm(), 1);
            assert_eq!(sem.permits(), 2);
        }

        /// The async path behaves like `try_acquire` when permits are available.
        #[futures_test::test]
        async fn acquire() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let held = sem.acquire(1).await.unwrap();
            assert_eq!(held.permits(), 1);
            assert_eq!(sem.permits(), 2);

            drop(held);
            assert_eq!(sem.permits(), 3);
        }

        /// `try_acquire_all` drains every available permit.
        #[test]
        fn try_acquire_all() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let held = sem.try_acquire_all(1).unwrap();
            assert_eq!(held.permits(), 3);
            assert_eq!(sem.permits(), 0);
        }

        /// `acquire_all` drains every available permit.
        #[futures_test::test]
        async fn acquire_all() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let held = sem.acquire_all(1).await.unwrap();
            assert_eq!(held.permits(), 3);
            assert_eq!(sem.permits(), 0);
        }

        /// Releasing adds permits on top of the current count.
        #[test]
        fn release() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(sem.permits(), 3);
            sem.release(2);
            assert_eq!(sem.permits(), 5);
        }

        /// `set` overwrites the current count.
        #[test]
        fn set() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(sem.permits(), 3);
            sem.set(2);
            assert_eq!(sem.permits(), 2);
        }

        /// A request larger than the remaining permits fails until permits are returned.
        #[test]
        fn contested() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let first = sem.try_acquire(1).unwrap();
            let second = sem.try_acquire(3);
            assert!(second.is_none());

            drop(first);

            let second = sem.try_acquire(3);
            assert!(second.is_some());
        }

        /// Waiters are served strictly in FIFO order and the queue capacity is enforced.
        #[futures_test::test]
        async fn fairness() {
            let sem = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = sem.try_acquire(1);
            assert!(a.is_some());

            let b_fut = sem.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut()); // First poll queues `b_fut`
            assert!(b.is_pending());

            let c = sem.try_acquire(1);
            assert!(c.is_none());

            let c_fut = sem.acquire(1);
            let mut c_fut = pin!(c_fut);
            let c = poll!(c_fut.as_mut()); // First poll queues `c_fut`
            assert!(c.is_pending()); // `c` waits behind `b`

            // Both queue slots are taken, so a third waiter is rejected.
            let d = sem.acquire(1).await;
            assert!(matches!(d, Err(WaitQueueFull)));

            drop(a);

            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending()); // `c` still waits behind `b`

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());

            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending()); // `b` now holds every permit

            drop(b);

            let c = poll!(c_fut.as_mut());
            assert!(c.is_ready());
        }

        /// Queued tasks running on an executor are woken once permits are released.
        #[futures_test::test]
        async fn wakers() {
            let executor = ThreadPool::new().unwrap();

            static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
            let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));

            let a = semaphore.try_acquire(2);
            assert!(a.is_some());

            let b_task = executor
                .spawn_with_handle(async move { semaphore.acquire(2).await })
                .unwrap();
            // Wait until `b_task` has joined the wait queue before spawning `c_task`.
            while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
                Delay::new(Duration::from_millis(50)).await;
            }

            let c_task = executor
                .spawn_with_handle(async move { semaphore.acquire(1).await })
                .unwrap();

            drop(a);

            let b = b_task.await.unwrap();
            assert_eq!(b.permits(), 2);

            let c = c_task.await.unwrap();
            assert_eq!(c.permits(), 1);
        }
    }
}
| 773 | |