1use std::future::Future;
2use std::sync::Arc;
3use std::task::Poll;
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5use tokio_util::sync::PollSemaphore;
6
/// Output of one acquisition attempt made through a `PollSemaphore` in these tests.
///
/// `Some(permit)` carries the granted permit(s); the tests expect `None` once
/// the underlying `Semaphore` has been closed.
type SemRet = Option<OwnedSemaphorePermit>;
8
9fn semaphore_poll(
10 sem: &mut PollSemaphore,
11) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
12 let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx));
13 tokio_test::task::spawn(fut)
14}
15
16fn semaphore_poll_many(
17 sem: &mut PollSemaphore,
18 permits: u32,
19) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
20 let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits));
21 tokio_test::task::spawn(fut)
22}
23
24#[tokio::test]
25async fn it_works() {
26 let sem = Arc::new(Semaphore::new(1));
27 let mut poll_sem = PollSemaphore::new(sem.clone());
28
29 let permit = sem.acquire().await.unwrap();
30 let mut poll = semaphore_poll(&mut poll_sem);
31 assert!(poll.poll().is_pending());
32 drop(permit);
33
34 assert!(matches!(poll.poll(), Poll::Ready(Some(_))));
35 drop(poll);
36
37 sem.close();
38
39 assert!(semaphore_poll(&mut poll_sem).await.is_none());
40
41 // Check that it is fused.
42 assert!(semaphore_poll(&mut poll_sem).await.is_none());
43 assert!(semaphore_poll(&mut poll_sem).await.is_none());
44}
45
46#[tokio::test]
47async fn can_acquire_many_permits() {
48 let sem = Arc::new(Semaphore::new(4));
49 let mut poll_sem = PollSemaphore::new(sem.clone());
50
51 let permit1 = semaphore_poll(&mut poll_sem).poll();
52 assert!(matches!(permit1, Poll::Ready(Some(_))));
53
54 let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll();
55 assert!(matches!(permit2, Poll::Ready(Some(_))));
56
57 assert_eq!(sem.available_permits(), 1);
58
59 drop(permit2);
60
61 let mut permit4 = semaphore_poll_many(&mut poll_sem, 4);
62 assert!(permit4.poll().is_pending());
63
64 drop(permit1);
65
66 let permit4 = permit4.poll();
67 assert!(matches!(permit4, Poll::Ready(Some(_))));
68 assert_eq!(sem.available_permits(), 0);
69}
70
71#[tokio::test]
72async fn can_poll_different_amounts_of_permits() {
73 let sem = Arc::new(Semaphore::new(4));
74 let mut poll_sem = PollSemaphore::new(sem.clone());
75 assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
76 assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
77
78 let permit = sem.acquire_many(4).await.unwrap();
79 assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
80 assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending());
81 drop(permit);
82 assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
83 assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
84}
85