1 | use crate::codec::Codec; |
2 | use crate::frame::Ping; |
3 | use crate::proto::{self, PingPayload}; |
4 | |
5 | use bytes::Buf; |
6 | use futures_util::task::AtomicWaker; |
7 | use std::io; |
8 | use std::sync::atomic::{AtomicUsize, Ordering}; |
9 | use std::sync::Arc; |
10 | use std::task::{Context, Poll}; |
11 | use tokio::io::AsyncWrite; |
12 | |
/// Acknowledges ping requests from the remote.
///
/// Also tracks a PING this side wants to send (set by `ping_shutdown`) and,
/// once `take_user_pings` has been called, the connection-side end of the
/// user ping state.
#[derive(Debug)]
pub(crate) struct PingPong {
    /// PING queued by `ping_shutdown` (payload `Ping::SHUTDOWN`), plus whether
    /// it has already been buffered into the codec.
    pending_ping: Option<PendingPing>,
    /// Payload of a received (non-ACK) PING that still has to be acknowledged
    /// by `send_pending_pong`.
    pending_pong: Option<PingPayload>,
    /// Connection-side end of the user ping state; `None` until
    /// `take_user_pings` is called.
    user_pings: Option<UserPingsRx>,
}
20 | |
/// User-facing handle used to request a PING (`send_ping`) and wait for its
/// acknowledgement (`poll_pong`). Shares its state with the connection's
/// `UserPingsRx` through an `Arc`.
#[derive(Debug)]
pub(crate) struct UserPings(Arc<UserPingsInner>);
23 | |
/// Connection-side end of the user ping state, owned by `PingPong`.
///
/// Dropping it marks the shared state as closed and wakes any task waiting in
/// `UserPings::poll_pong`.
#[derive(Debug)]
struct UserPingsRx(Arc<UserPingsInner>);
26 | |
/// State shared between `UserPings` (user side) and `UserPingsRx`
/// (connection side).
#[derive(Debug)]
struct UserPingsInner {
    /// Current user ping state; always one of the `USER_STATE_*` constants.
    state: AtomicUsize,
    /// Task to wake up the main `Connection`.
    /// Registered in `PingPong::send_pending_ping`, woken by `UserPings::send_ping`.
    ping_task: AtomicWaker,
    /// Task to wake up `share::PingPong::poll_pong`.
    /// Registered in `UserPings::poll_pong`, woken by `UserPingsRx::receive_pong`
    /// and when `UserPingsRx` is dropped.
    pong_task: AtomicWaker,
}
35 | |
/// A PING frame this side intends to send.
#[derive(Debug)]
struct PendingPing {
    /// Opaque payload; an ACK carrying the same payload completes this ping
    /// (see `PingPong::recv_ping`).
    payload: PingPayload,
    /// Whether the PING has already been buffered into the codec.
    sent: bool,
}
41 | |
/// Status returned from `PingPong::recv_ping`.
#[derive(Debug)]
pub(crate) enum ReceivedPing {
    /// A non-ACK PING arrived; its payload was queued and must be echoed back
    /// via `send_pending_pong`.
    MustAck,
    /// An ACK that was not for the shutdown PING: either the ACK of a user
    /// ping, or an unexpected ACK that is logged and ignored.
    Unknown,
    /// The ACK for the shutdown PING arrived.
    Shutdown,
}
49 | |
// User ping state machine, stored in `UserPingsInner::state`:
//
//   EMPTY --(UserPings::send_ping)--> PENDING_PING
//   PENDING_PING --(PING written by send_pending_ping)--> PENDING_PONG
//   PENDING_PONG --(ACK seen by UserPingsRx::receive_pong)--> RECEIVED_PONG
//   RECEIVED_PONG --(UserPings::poll_pong)--> EMPTY
//
// Dropping `UserPingsRx` stores CLOSED unconditionally.

/// No user ping pending.
const USER_STATE_EMPTY: usize = 0;
/// User has called `send_ping`, but PING hasn't been written yet.
const USER_STATE_PENDING_PING: usize = 1;
/// User PING has been written, waiting for PONG.
const USER_STATE_PENDING_PONG: usize = 2;
/// We've received user PONG, waiting for user to `poll_pong`.
const USER_STATE_RECEIVED_PONG: usize = 3;
/// The connection is closed.
const USER_STATE_CLOSED: usize = 4;
60 | |
61 | // ===== impl PingPong ===== |
62 | |
63 | impl PingPong { |
64 | pub(crate) fn new() -> Self { |
65 | PingPong { |
66 | pending_ping: None, |
67 | pending_pong: None, |
68 | user_pings: None, |
69 | } |
70 | } |
71 | |
72 | /// Can only be called once. If called a second time, returns `None`. |
73 | pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> { |
74 | if self.user_pings.is_some() { |
75 | return None; |
76 | } |
77 | |
78 | let user_pings = Arc::new(UserPingsInner { |
79 | state: AtomicUsize::new(USER_STATE_EMPTY), |
80 | ping_task: AtomicWaker::new(), |
81 | pong_task: AtomicWaker::new(), |
82 | }); |
83 | self.user_pings = Some(UserPingsRx(user_pings.clone())); |
84 | Some(UserPings(user_pings)) |
85 | } |
86 | |
87 | pub(crate) fn ping_shutdown(&mut self) { |
88 | assert!(self.pending_ping.is_none()); |
89 | |
90 | self.pending_ping = Some(PendingPing { |
91 | payload: Ping::SHUTDOWN, |
92 | sent: false, |
93 | }); |
94 | } |
95 | |
96 | /// Process a ping |
97 | pub(crate) fn recv_ping(&mut self, ping: Ping) -> ReceivedPing { |
98 | // The caller should always check that `send_pongs` returns ready before |
99 | // calling `recv_ping`. |
100 | assert!(self.pending_pong.is_none()); |
101 | |
102 | if ping.is_ack() { |
103 | if let Some(pending) = self.pending_ping.take() { |
104 | if &pending.payload == ping.payload() { |
105 | assert_eq!( |
106 | &pending.payload, |
107 | &Ping::SHUTDOWN, |
108 | "pending_ping should be for shutdown" , |
109 | ); |
110 | tracing::trace!("recv PING SHUTDOWN ack" ); |
111 | return ReceivedPing::Shutdown; |
112 | } |
113 | |
114 | // if not the payload we expected, put it back. |
115 | self.pending_ping = Some(pending); |
116 | } |
117 | |
118 | if let Some(ref users) = self.user_pings { |
119 | if ping.payload() == &Ping::USER && users.receive_pong() { |
120 | tracing::trace!("recv PING USER ack" ); |
121 | return ReceivedPing::Unknown; |
122 | } |
123 | } |
124 | |
125 | // else we were acked a ping we didn't send? |
126 | // The spec doesn't require us to do anything about this, |
127 | // so for resiliency, just ignore it for now. |
128 | tracing::warn!("recv PING ack that we never sent: {:?}" , ping); |
129 | ReceivedPing::Unknown |
130 | } else { |
131 | // Save the ping's payload to be sent as an acknowledgement. |
132 | self.pending_pong = Some(ping.into_payload()); |
133 | ReceivedPing::MustAck |
134 | } |
135 | } |
136 | |
137 | /// Send any pending pongs. |
138 | pub(crate) fn send_pending_pong<T, B>( |
139 | &mut self, |
140 | cx: &mut Context, |
141 | dst: &mut Codec<T, B>, |
142 | ) -> Poll<io::Result<()>> |
143 | where |
144 | T: AsyncWrite + Unpin, |
145 | B: Buf, |
146 | { |
147 | if let Some(pong) = self.pending_pong.take() { |
148 | if !dst.poll_ready(cx)?.is_ready() { |
149 | self.pending_pong = Some(pong); |
150 | return Poll::Pending; |
151 | } |
152 | |
153 | dst.buffer(Ping::pong(pong).into()) |
154 | .expect("invalid pong frame" ); |
155 | } |
156 | |
157 | Poll::Ready(Ok(())) |
158 | } |
159 | |
160 | /// Send any pending pings. |
161 | pub(crate) fn send_pending_ping<T, B>( |
162 | &mut self, |
163 | cx: &mut Context, |
164 | dst: &mut Codec<T, B>, |
165 | ) -> Poll<io::Result<()>> |
166 | where |
167 | T: AsyncWrite + Unpin, |
168 | B: Buf, |
169 | { |
170 | if let Some(ref mut ping) = self.pending_ping { |
171 | if !ping.sent { |
172 | if !dst.poll_ready(cx)?.is_ready() { |
173 | return Poll::Pending; |
174 | } |
175 | |
176 | dst.buffer(Ping::new(ping.payload).into()) |
177 | .expect("invalid ping frame" ); |
178 | ping.sent = true; |
179 | } |
180 | } else if let Some(ref users) = self.user_pings { |
181 | if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING { |
182 | if !dst.poll_ready(cx)?.is_ready() { |
183 | return Poll::Pending; |
184 | } |
185 | |
186 | dst.buffer(Ping::new(Ping::USER).into()) |
187 | .expect("invalid ping frame" ); |
188 | users |
189 | .0 |
190 | .state |
191 | .store(USER_STATE_PENDING_PONG, Ordering::Release); |
192 | } else { |
193 | users.0.ping_task.register(cx.waker()); |
194 | } |
195 | } |
196 | |
197 | Poll::Ready(Ok(())) |
198 | } |
199 | } |
200 | |
201 | impl ReceivedPing { |
202 | pub(crate) fn is_shutdown(&self) -> bool { |
203 | matches!(*self, Self::Shutdown) |
204 | } |
205 | } |
206 | |
207 | // ===== impl UserPings ===== |
208 | |
209 | impl UserPings { |
210 | pub(crate) fn send_ping(&self) -> Result<(), Option<proto::Error>> { |
211 | let prev = self |
212 | .0 |
213 | .state |
214 | .compare_exchange( |
215 | USER_STATE_EMPTY, // current |
216 | USER_STATE_PENDING_PING, // new |
217 | Ordering::AcqRel, |
218 | Ordering::Acquire, |
219 | ) |
220 | .unwrap_or_else(|v| v); |
221 | |
222 | match prev { |
223 | USER_STATE_EMPTY => { |
224 | self.0.ping_task.wake(); |
225 | Ok(()) |
226 | } |
227 | USER_STATE_CLOSED => Err(Some(broken_pipe().into())), |
228 | _ => { |
229 | // Was already pending, user error! |
230 | Err(None) |
231 | } |
232 | } |
233 | } |
234 | |
235 | pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>> { |
236 | // Must register before checking state, in case state were to change |
237 | // before we could register, and then the ping would just be lost. |
238 | self.0.pong_task.register(cx.waker()); |
239 | let prev = self |
240 | .0 |
241 | .state |
242 | .compare_exchange( |
243 | USER_STATE_RECEIVED_PONG, // current |
244 | USER_STATE_EMPTY, // new |
245 | Ordering::AcqRel, |
246 | Ordering::Acquire, |
247 | ) |
248 | .unwrap_or_else(|v| v); |
249 | |
250 | match prev { |
251 | USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())), |
252 | USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())), |
253 | _ => Poll::Pending, |
254 | } |
255 | } |
256 | } |
257 | |
258 | // ===== impl UserPingsRx ===== |
259 | |
260 | impl UserPingsRx { |
261 | fn receive_pong(&self) -> bool { |
262 | let prev: usize = self |
263 | .0 |
264 | .state |
265 | .compare_exchange( |
266 | USER_STATE_PENDING_PONG, // current |
267 | USER_STATE_RECEIVED_PONG, // new |
268 | Ordering::AcqRel, |
269 | Ordering::Acquire, |
270 | ) |
271 | .unwrap_or_else(|v: usize| v); |
272 | |
273 | if prev == USER_STATE_PENDING_PONG { |
274 | self.0.pong_task.wake(); |
275 | true |
276 | } else { |
277 | false |
278 | } |
279 | } |
280 | } |
281 | |
282 | impl Drop for UserPingsRx { |
283 | fn drop(&mut self) { |
284 | self.0.state.store(USER_STATE_CLOSED, order:Ordering::Release); |
285 | self.0.pong_task.wake(); |
286 | } |
287 | } |
288 | |
/// Error reported to the user ping handle once the connection side is gone.
fn broken_pipe() -> io::Error {
    io::Error::from(io::ErrorKind::BrokenPipe)
}
292 | |