| 1 | use crate::codec::Codec; |
| 2 | use crate::frame::Ping; |
| 3 | use crate::proto::{self, PingPayload}; |
| 4 | |
| 5 | use atomic_waker::AtomicWaker; |
| 6 | use bytes::Buf; |
| 7 | use std::io; |
| 8 | use std::sync::atomic::{AtomicUsize, Ordering}; |
| 9 | use std::sync::Arc; |
| 10 | use std::task::{Context, Poll}; |
| 11 | use tokio::io::AsyncWrite; |
| 12 | |
| 13 | /// Acknowledges ping requests from the remote. |
| 14 | #[derive (Debug)] |
| 15 | pub(crate) struct PingPong { |
| 16 | pending_ping: Option<PendingPing>, |
| 17 | pending_pong: Option<PingPayload>, |
| 18 | user_pings: Option<UserPingsRx>, |
| 19 | } |
| 20 | |
| 21 | #[derive (Debug)] |
| 22 | pub(crate) struct UserPings(Arc<UserPingsInner>); |
| 23 | |
| 24 | #[derive (Debug)] |
| 25 | struct UserPingsRx(Arc<UserPingsInner>); |
| 26 | |
| 27 | #[derive (Debug)] |
| 28 | struct UserPingsInner { |
| 29 | state: AtomicUsize, |
| 30 | /// Task to wake up the main `Connection`. |
| 31 | ping_task: AtomicWaker, |
| 32 | /// Task to wake up `share::PingPong::poll_pong`. |
| 33 | pong_task: AtomicWaker, |
| 34 | } |
| 35 | |
| 36 | #[derive (Debug)] |
| 37 | struct PendingPing { |
| 38 | payload: PingPayload, |
| 39 | sent: bool, |
| 40 | } |
| 41 | |
/// Status returned from `PingPong::recv_ping`.
#[derive(Debug)]
pub(crate) enum ReceivedPing {
    /// A PING (not an ack) arrived; its payload is queued to be echoed back.
    MustAck,
    /// An ack arrived that was either a user PONG or one we never asked for.
    Unknown,
    /// The ack for our pending shutdown PING arrived.
    Shutdown,
}
| 49 | |
// User-ping state machine, stored in `UserPingsInner::state`:
//
//   EMPTY --send_ping--> PENDING_PING --send_pending_ping--> PENDING_PONG
//     ^                                                           |
//     +-------poll_pong------- RECEIVED_PONG <---receive_pong-----+
//
// Dropping `UserPingsRx` stores CLOSED regardless of the current state.

/// No user ping pending.
const USER_STATE_EMPTY: usize = 0;
/// User has called `send_ping`, but the PING hasn't been written yet.
const USER_STATE_PENDING_PING: usize = 1;
/// User PING has been written; waiting for the peer's PONG.
const USER_STATE_PENDING_PONG: usize = 2;
/// The user PONG has been received; waiting for the user to `poll_pong`.
const USER_STATE_RECEIVED_PONG: usize = 3;
/// The connection is closed.
const USER_STATE_CLOSED: usize = 4;
| 60 | |
| 61 | // ===== impl PingPong ===== |
| 62 | |
| 63 | impl PingPong { |
| 64 | pub(crate) fn new() -> Self { |
| 65 | PingPong { |
| 66 | pending_ping: None, |
| 67 | pending_pong: None, |
| 68 | user_pings: None, |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | /// Can only be called once. If called a second time, returns `None`. |
| 73 | pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> { |
| 74 | if self.user_pings.is_some() { |
| 75 | return None; |
| 76 | } |
| 77 | |
| 78 | let user_pings = Arc::new(UserPingsInner { |
| 79 | state: AtomicUsize::new(USER_STATE_EMPTY), |
| 80 | ping_task: AtomicWaker::new(), |
| 81 | pong_task: AtomicWaker::new(), |
| 82 | }); |
| 83 | self.user_pings = Some(UserPingsRx(user_pings.clone())); |
| 84 | Some(UserPings(user_pings)) |
| 85 | } |
| 86 | |
| 87 | pub(crate) fn ping_shutdown(&mut self) { |
| 88 | assert!(self.pending_ping.is_none()); |
| 89 | |
| 90 | self.pending_ping = Some(PendingPing { |
| 91 | payload: Ping::SHUTDOWN, |
| 92 | sent: false, |
| 93 | }); |
| 94 | } |
| 95 | |
| 96 | /// Process a ping |
| 97 | pub(crate) fn recv_ping(&mut self, ping: Ping) -> ReceivedPing { |
| 98 | // The caller should always check that `send_pongs` returns ready before |
| 99 | // calling `recv_ping`. |
| 100 | assert!(self.pending_pong.is_none()); |
| 101 | |
| 102 | if ping.is_ack() { |
| 103 | if let Some(pending) = self.pending_ping.take() { |
| 104 | if &pending.payload == ping.payload() { |
| 105 | assert_eq!( |
| 106 | &pending.payload, |
| 107 | &Ping::SHUTDOWN, |
| 108 | "pending_ping should be for shutdown" , |
| 109 | ); |
| 110 | tracing::trace!("recv PING SHUTDOWN ack" ); |
| 111 | return ReceivedPing::Shutdown; |
| 112 | } |
| 113 | |
| 114 | // if not the payload we expected, put it back. |
| 115 | self.pending_ping = Some(pending); |
| 116 | } |
| 117 | |
| 118 | if let Some(ref users) = self.user_pings { |
| 119 | if ping.payload() == &Ping::USER && users.receive_pong() { |
| 120 | tracing::trace!("recv PING USER ack" ); |
| 121 | return ReceivedPing::Unknown; |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | // else we were acked a ping we didn't send? |
| 126 | // The spec doesn't require us to do anything about this, |
| 127 | // so for resiliency, just ignore it for now. |
| 128 | tracing::warn!("recv PING ack that we never sent: {:?}" , ping); |
| 129 | ReceivedPing::Unknown |
| 130 | } else { |
| 131 | // Save the ping's payload to be sent as an acknowledgement. |
| 132 | self.pending_pong = Some(ping.into_payload()); |
| 133 | ReceivedPing::MustAck |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | /// Send any pending pongs. |
| 138 | pub(crate) fn send_pending_pong<T, B>( |
| 139 | &mut self, |
| 140 | cx: &mut Context, |
| 141 | dst: &mut Codec<T, B>, |
| 142 | ) -> Poll<io::Result<()>> |
| 143 | where |
| 144 | T: AsyncWrite + Unpin, |
| 145 | B: Buf, |
| 146 | { |
| 147 | if let Some(pong) = self.pending_pong.take() { |
| 148 | if !dst.poll_ready(cx)?.is_ready() { |
| 149 | self.pending_pong = Some(pong); |
| 150 | return Poll::Pending; |
| 151 | } |
| 152 | |
| 153 | dst.buffer(Ping::pong(pong).into()) |
| 154 | .expect("invalid pong frame" ); |
| 155 | } |
| 156 | |
| 157 | Poll::Ready(Ok(())) |
| 158 | } |
| 159 | |
| 160 | /// Send any pending pings. |
| 161 | pub(crate) fn send_pending_ping<T, B>( |
| 162 | &mut self, |
| 163 | cx: &mut Context, |
| 164 | dst: &mut Codec<T, B>, |
| 165 | ) -> Poll<io::Result<()>> |
| 166 | where |
| 167 | T: AsyncWrite + Unpin, |
| 168 | B: Buf, |
| 169 | { |
| 170 | if let Some(ref mut ping) = self.pending_ping { |
| 171 | if !ping.sent { |
| 172 | if !dst.poll_ready(cx)?.is_ready() { |
| 173 | return Poll::Pending; |
| 174 | } |
| 175 | |
| 176 | dst.buffer(Ping::new(ping.payload).into()) |
| 177 | .expect("invalid ping frame" ); |
| 178 | ping.sent = true; |
| 179 | } |
| 180 | } else if let Some(ref users) = self.user_pings { |
| 181 | if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING { |
| 182 | if !dst.poll_ready(cx)?.is_ready() { |
| 183 | return Poll::Pending; |
| 184 | } |
| 185 | |
| 186 | dst.buffer(Ping::new(Ping::USER).into()) |
| 187 | .expect("invalid ping frame" ); |
| 188 | users |
| 189 | .0 |
| 190 | .state |
| 191 | .store(USER_STATE_PENDING_PONG, Ordering::Release); |
| 192 | } else { |
| 193 | users.0.ping_task.register(cx.waker()); |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | Poll::Ready(Ok(())) |
| 198 | } |
| 199 | } |
| 200 | |
| 201 | impl ReceivedPing { |
| 202 | pub(crate) fn is_shutdown(&self) -> bool { |
| 203 | matches!(*self, Self::Shutdown) |
| 204 | } |
| 205 | } |
| 206 | |
| 207 | // ===== impl UserPings ===== |
| 208 | |
| 209 | impl UserPings { |
| 210 | pub(crate) fn send_ping(&self) -> Result<(), Option<proto::Error>> { |
| 211 | let prev = self |
| 212 | .0 |
| 213 | .state |
| 214 | .compare_exchange( |
| 215 | USER_STATE_EMPTY, // current |
| 216 | USER_STATE_PENDING_PING, // new |
| 217 | Ordering::AcqRel, |
| 218 | Ordering::Acquire, |
| 219 | ) |
| 220 | .unwrap_or_else(|v| v); |
| 221 | |
| 222 | match prev { |
| 223 | USER_STATE_EMPTY => { |
| 224 | self.0.ping_task.wake(); |
| 225 | Ok(()) |
| 226 | } |
| 227 | USER_STATE_CLOSED => Err(Some(broken_pipe().into())), |
| 228 | _ => { |
| 229 | // Was already pending, user error! |
| 230 | Err(None) |
| 231 | } |
| 232 | } |
| 233 | } |
| 234 | |
| 235 | pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>> { |
| 236 | // Must register before checking state, in case state were to change |
| 237 | // before we could register, and then the ping would just be lost. |
| 238 | self.0.pong_task.register(cx.waker()); |
| 239 | let prev = self |
| 240 | .0 |
| 241 | .state |
| 242 | .compare_exchange( |
| 243 | USER_STATE_RECEIVED_PONG, // current |
| 244 | USER_STATE_EMPTY, // new |
| 245 | Ordering::AcqRel, |
| 246 | Ordering::Acquire, |
| 247 | ) |
| 248 | .unwrap_or_else(|v| v); |
| 249 | |
| 250 | match prev { |
| 251 | USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())), |
| 252 | USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())), |
| 253 | _ => Poll::Pending, |
| 254 | } |
| 255 | } |
| 256 | } |
| 257 | |
| 258 | // ===== impl UserPingsRx ===== |
| 259 | |
| 260 | impl UserPingsRx { |
| 261 | fn receive_pong(&self) -> bool { |
| 262 | let prev: usize = self |
| 263 | .0 |
| 264 | .state |
| 265 | .compare_exchange( |
| 266 | USER_STATE_PENDING_PONG, // current |
| 267 | USER_STATE_RECEIVED_PONG, // new |
| 268 | Ordering::AcqRel, |
| 269 | Ordering::Acquire, |
| 270 | ) |
| 271 | .unwrap_or_else(|v: usize| v); |
| 272 | |
| 273 | if prev == USER_STATE_PENDING_PONG { |
| 274 | self.0.pong_task.wake(); |
| 275 | true |
| 276 | } else { |
| 277 | false |
| 278 | } |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | impl Drop for UserPingsRx { |
| 283 | fn drop(&mut self) { |
| 284 | self.0.state.store(USER_STATE_CLOSED, order:Ordering::Release); |
| 285 | self.0.pong_task.wake(); |
| 286 | } |
| 287 | } |
| 288 | |
/// Builds the `io::Error` reported to users once the connection has closed.
fn broken_pipe() -> io::Error {
    io::Error::from(io::ErrorKind::BrokenPipe)
}
| 292 | |