1 | use futures_core::{ready, Stream}; |
2 | use std::fmt; |
3 | use std::pin::Pin; |
4 | use std::sync::Arc; |
5 | use std::task::{Context, Poll}; |
6 | use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; |
7 | |
8 | use super::ReusableBoxFuture; |
9 | |
/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
///
/// The wrapper keeps a shared handle to the semaphore together with a
/// lazily created, reusable acquire future, so that permits can be obtained
/// from poll-based code (including through the `Stream` implementation).
///
/// [`Semaphore`]: tokio::sync::Semaphore
pub struct PollSemaphore {
    // Shared handle to the semaphore that permits are acquired from.
    semaphore: Arc<Semaphore>,
    // Pending acquire future. `None` until `poll_acquire` fails to obtain a
    // permit immediately, and reset to `None` once the semaphore is observed
    // to be closed.
    permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>,
}
17 | |
18 | impl PollSemaphore { |
19 | /// Create a new `PollSemaphore`. |
20 | pub fn new(semaphore: Arc<Semaphore>) -> Self { |
21 | Self { |
22 | semaphore, |
23 | permit_fut: None, |
24 | } |
25 | } |
26 | |
27 | /// Closes the semaphore. |
28 | pub fn close(&self) { |
29 | self.semaphore.close() |
30 | } |
31 | |
32 | /// Obtain a clone of the inner semaphore. |
33 | pub fn clone_inner(&self) -> Arc<Semaphore> { |
34 | self.semaphore.clone() |
35 | } |
36 | |
37 | /// Get back the inner semaphore. |
38 | pub fn into_inner(self) -> Arc<Semaphore> { |
39 | self.semaphore |
40 | } |
41 | |
42 | /// Poll to acquire a permit from the semaphore. |
43 | /// |
44 | /// This can return the following values: |
45 | /// |
46 | /// - `Poll::Pending` if a permit is not currently available. |
47 | /// - `Poll::Ready(Some(permit))` if a permit was acquired. |
48 | /// - `Poll::Ready(None)` if the semaphore has been closed. |
49 | /// |
50 | /// When this method returns `Poll::Pending`, the current task is scheduled |
51 | /// to receive a wakeup when a permit becomes available, or when the |
52 | /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only |
53 | /// the `Waker` from the `Context` passed to the most recent call is |
54 | /// scheduled to receive a wakeup. |
55 | pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { |
56 | let permit_future = match self.permit_fut.as_mut() { |
57 | Some(fut) => fut, |
58 | None => { |
59 | // avoid allocations completely if we can grab a permit immediately |
60 | match Arc::clone(&self.semaphore).try_acquire_owned() { |
61 | Ok(permit) => return Poll::Ready(Some(permit)), |
62 | Err(TryAcquireError::Closed) => return Poll::Ready(None), |
63 | Err(TryAcquireError::NoPermits) => {} |
64 | } |
65 | |
66 | let next_fut = Arc::clone(&self.semaphore).acquire_owned(); |
67 | self.permit_fut |
68 | .get_or_insert(ReusableBoxFuture::new(next_fut)) |
69 | } |
70 | }; |
71 | |
72 | let result = ready!(permit_future.poll(cx)); |
73 | |
74 | let next_fut = Arc::clone(&self.semaphore).acquire_owned(); |
75 | permit_future.set(next_fut); |
76 | |
77 | match result { |
78 | Ok(permit) => Poll::Ready(Some(permit)), |
79 | Err(_closed) => { |
80 | self.permit_fut = None; |
81 | Poll::Ready(None) |
82 | } |
83 | } |
84 | } |
85 | |
86 | /// Returns the current number of available permits. |
87 | /// |
88 | /// This is equivalent to the [`Semaphore::available_permits`] method on the |
89 | /// `tokio::sync::Semaphore` type. |
90 | /// |
91 | /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits |
92 | pub fn available_permits(&self) -> usize { |
93 | self.semaphore.available_permits() |
94 | } |
95 | |
96 | /// Adds `n` new permits to the semaphore. |
97 | /// |
98 | /// The maximum number of permits is `usize::MAX >> 3`, and this function |
99 | /// will panic if the limit is exceeded. |
100 | /// |
101 | /// This is equivalent to the [`Semaphore::add_permits`] method on the |
102 | /// `tokio::sync::Semaphore` type. |
103 | /// |
104 | /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits |
105 | pub fn add_permits(&self, n: usize) { |
106 | self.semaphore.add_permits(n); |
107 | } |
108 | } |
109 | |
110 | impl Stream for PollSemaphore { |
111 | type Item = OwnedSemaphorePermit; |
112 | |
113 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { |
114 | Pin::into_inner(self).poll_acquire(cx) |
115 | } |
116 | } |
117 | |
118 | impl Clone for PollSemaphore { |
119 | fn clone(&self) -> PollSemaphore { |
120 | PollSemaphore::new(self.clone_inner()) |
121 | } |
122 | } |
123 | |
124 | impl fmt::Debug for PollSemaphore { |
125 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
126 | f&mut DebugStruct<'_, '_>.debug_struct("PollSemaphore" ) |
127 | .field(name:"semaphore" , &self.semaphore) |
128 | .finish() |
129 | } |
130 | } |
131 | |
132 | impl AsRef<Semaphore> for PollSemaphore { |
133 | fn as_ref(&self) -> &Semaphore { |
134 | &*self.semaphore |
135 | } |
136 | } |
137 | |