1use crate::codec::Codec;
2use crate::frame::Ping;
3use crate::proto::{self, PingPayload};
4
5use bytes::Buf;
6use futures_util::task::AtomicWaker;
7use std::io;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::io::AsyncWrite;
12
/// Acknowledges ping requests from the remote.
///
/// Also tracks the PING queued by `ping_shutdown` and the receiver half of
/// the user ping state handed out by `take_user_pings`.
#[derive(Debug)]
pub(crate) struct PingPong {
    /// PING queued by `ping_shutdown`, together with whether it has been
    /// buffered for writing yet.
    pending_ping: Option<PendingPing>,
    /// Payload of a received PING that still has to be echoed back by
    /// `send_pending_pong`.
    pending_pong: Option<PingPayload>,
    /// Receiver half of the user ping state; `None` until
    /// `take_user_pings` has been called.
    user_pings: Option<UserPingsRx>,
}
20
/// Handle to the shared user ping state, created by
/// `PingPong::take_user_pings`. Exposes `send_ping` and `poll_pong`.
#[derive(Debug)]
pub(crate) struct UserPings(Arc<UserPingsInner>);
23
/// Half of the shared user ping state that is stored inside `PingPong`.
///
/// It records received PONGs via `receive_pong`; dropping it marks the
/// shared state as closed.
#[derive(Debug)]
struct UserPingsRx(Arc<UserPingsInner>);
26
/// State shared between `UserPings` and `UserPingsRx`.
#[derive(Debug)]
struct UserPingsInner {
    /// Current step of the user ping; holds one of the `USER_STATE_*`
    /// constants.
    state: AtomicUsize,
    /// Task to wake up the main `Connection`.
    ping_task: AtomicWaker,
    /// Task to wake up `share::PingPong::poll_pong`.
    pong_task: AtomicWaker,
}
35
/// A PING queued by this side and its write status.
#[derive(Debug)]
struct PendingPing {
    /// Payload to send; `recv_ping` compares incoming acks against it.
    payload: PingPayload,
    /// Set to `true` once `send_pending_ping` has buffered the frame.
    sent: bool,
}
41
/// Status returned from `PingPong::recv_ping`.
#[derive(Debug)]
pub(crate) enum ReceivedPing {
    /// A non-ack PING arrived; its payload was stored so it can be echoed
    /// back.
    MustAck,
    /// An ack arrived that is not the shutdown ack: either the ack of a user
    /// ping, or an ack for a ping that was never sent.
    Unknown,
    /// The ack matching the pending shutdown PING arrived.
    Shutdown,
}
49
/// No user ping pending.
///
/// This is the initial state, and the state `poll_pong` resets to once it
/// has observed a received PONG.
const USER_STATE_EMPTY: usize = 0;
/// User has called `send_ping`, but PING hasn't been written yet.
const USER_STATE_PENDING_PING: usize = 1;
/// User PING has been written, waiting for PONG.
const USER_STATE_PENDING_PONG: usize = 2;
/// We've received user PONG, waiting for user to `poll_pong`.
const USER_STATE_RECEIVED_PONG: usize = 3;
/// The connection is closed.
///
/// Stored when the `UserPingsRx` half is dropped.
const USER_STATE_CLOSED: usize = 4;
60
61// ===== impl PingPong =====
62
63impl PingPong {
64 pub(crate) fn new() -> Self {
65 PingPong {
66 pending_ping: None,
67 pending_pong: None,
68 user_pings: None,
69 }
70 }
71
72 /// Can only be called once. If called a second time, returns `None`.
73 pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
74 if self.user_pings.is_some() {
75 return None;
76 }
77
78 let user_pings = Arc::new(UserPingsInner {
79 state: AtomicUsize::new(USER_STATE_EMPTY),
80 ping_task: AtomicWaker::new(),
81 pong_task: AtomicWaker::new(),
82 });
83 self.user_pings = Some(UserPingsRx(user_pings.clone()));
84 Some(UserPings(user_pings))
85 }
86
87 pub(crate) fn ping_shutdown(&mut self) {
88 assert!(self.pending_ping.is_none());
89
90 self.pending_ping = Some(PendingPing {
91 payload: Ping::SHUTDOWN,
92 sent: false,
93 });
94 }
95
96 /// Process a ping
97 pub(crate) fn recv_ping(&mut self, ping: Ping) -> ReceivedPing {
98 // The caller should always check that `send_pongs` returns ready before
99 // calling `recv_ping`.
100 assert!(self.pending_pong.is_none());
101
102 if ping.is_ack() {
103 if let Some(pending) = self.pending_ping.take() {
104 if &pending.payload == ping.payload() {
105 assert_eq!(
106 &pending.payload,
107 &Ping::SHUTDOWN,
108 "pending_ping should be for shutdown",
109 );
110 tracing::trace!("recv PING SHUTDOWN ack");
111 return ReceivedPing::Shutdown;
112 }
113
114 // if not the payload we expected, put it back.
115 self.pending_ping = Some(pending);
116 }
117
118 if let Some(ref users) = self.user_pings {
119 if ping.payload() == &Ping::USER && users.receive_pong() {
120 tracing::trace!("recv PING USER ack");
121 return ReceivedPing::Unknown;
122 }
123 }
124
125 // else we were acked a ping we didn't send?
126 // The spec doesn't require us to do anything about this,
127 // so for resiliency, just ignore it for now.
128 tracing::warn!("recv PING ack that we never sent: {:?}", ping);
129 ReceivedPing::Unknown
130 } else {
131 // Save the ping's payload to be sent as an acknowledgement.
132 self.pending_pong = Some(ping.into_payload());
133 ReceivedPing::MustAck
134 }
135 }
136
137 /// Send any pending pongs.
138 pub(crate) fn send_pending_pong<T, B>(
139 &mut self,
140 cx: &mut Context,
141 dst: &mut Codec<T, B>,
142 ) -> Poll<io::Result<()>>
143 where
144 T: AsyncWrite + Unpin,
145 B: Buf,
146 {
147 if let Some(pong) = self.pending_pong.take() {
148 if !dst.poll_ready(cx)?.is_ready() {
149 self.pending_pong = Some(pong);
150 return Poll::Pending;
151 }
152
153 dst.buffer(Ping::pong(pong).into())
154 .expect("invalid pong frame");
155 }
156
157 Poll::Ready(Ok(()))
158 }
159
160 /// Send any pending pings.
161 pub(crate) fn send_pending_ping<T, B>(
162 &mut self,
163 cx: &mut Context,
164 dst: &mut Codec<T, B>,
165 ) -> Poll<io::Result<()>>
166 where
167 T: AsyncWrite + Unpin,
168 B: Buf,
169 {
170 if let Some(ref mut ping) = self.pending_ping {
171 if !ping.sent {
172 if !dst.poll_ready(cx)?.is_ready() {
173 return Poll::Pending;
174 }
175
176 dst.buffer(Ping::new(ping.payload).into())
177 .expect("invalid ping frame");
178 ping.sent = true;
179 }
180 } else if let Some(ref users) = self.user_pings {
181 if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING {
182 if !dst.poll_ready(cx)?.is_ready() {
183 return Poll::Pending;
184 }
185
186 dst.buffer(Ping::new(Ping::USER).into())
187 .expect("invalid ping frame");
188 users
189 .0
190 .state
191 .store(USER_STATE_PENDING_PONG, Ordering::Release);
192 } else {
193 users.0.ping_task.register(cx.waker());
194 }
195 }
196
197 Poll::Ready(Ok(()))
198 }
199}
200
201impl ReceivedPing {
202 pub(crate) fn is_shutdown(&self) -> bool {
203 matches!(*self, Self::Shutdown)
204 }
205}
206
207// ===== impl UserPings =====
208
209impl UserPings {
210 pub(crate) fn send_ping(&self) -> Result<(), Option<proto::Error>> {
211 let prev = self
212 .0
213 .state
214 .compare_exchange(
215 USER_STATE_EMPTY, // current
216 USER_STATE_PENDING_PING, // new
217 Ordering::AcqRel,
218 Ordering::Acquire,
219 )
220 .unwrap_or_else(|v| v);
221
222 match prev {
223 USER_STATE_EMPTY => {
224 self.0.ping_task.wake();
225 Ok(())
226 }
227 USER_STATE_CLOSED => Err(Some(broken_pipe().into())),
228 _ => {
229 // Was already pending, user error!
230 Err(None)
231 }
232 }
233 }
234
235 pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>> {
236 // Must register before checking state, in case state were to change
237 // before we could register, and then the ping would just be lost.
238 self.0.pong_task.register(cx.waker());
239 let prev = self
240 .0
241 .state
242 .compare_exchange(
243 USER_STATE_RECEIVED_PONG, // current
244 USER_STATE_EMPTY, // new
245 Ordering::AcqRel,
246 Ordering::Acquire,
247 )
248 .unwrap_or_else(|v| v);
249
250 match prev {
251 USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())),
252 USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())),
253 _ => Poll::Pending,
254 }
255 }
256}
257
258// ===== impl UserPingsRx =====
259
260impl UserPingsRx {
261 fn receive_pong(&self) -> bool {
262 let prev: usize = self
263 .0
264 .state
265 .compare_exchange(
266 USER_STATE_PENDING_PONG, // current
267 USER_STATE_RECEIVED_PONG, // new
268 Ordering::AcqRel,
269 Ordering::Acquire,
270 )
271 .unwrap_or_else(|v: usize| v);
272
273 if prev == USER_STATE_PENDING_PONG {
274 self.0.pong_task.wake();
275 true
276 } else {
277 false
278 }
279 }
280}
281
282impl Drop for UserPingsRx {
283 fn drop(&mut self) {
284 self.0.state.store(USER_STATE_CLOSED, order:Ordering::Release);
285 self.0.pong_task.wake();
286 }
287}
288
/// Builds the `io::Error` of kind `BrokenPipe` reported once the user ping
/// state is closed.
fn broken_pipe() -> io::Error {
    io::Error::from(io::ErrorKind::BrokenPipe)
}
292