| 1 | use crate::codec::{Decoder, Encoder}; |
| 2 | |
| 3 | use futures_core::Stream; |
| 4 | use tokio::{io::ReadBuf, net::UdpSocket}; |
| 5 | |
| 6 | use bytes::{BufMut, BytesMut}; |
| 7 | use futures_core::ready; |
| 8 | use futures_sink::Sink; |
| 9 | use std::pin::Pin; |
| 10 | use std::task::{Context, Poll}; |
| 11 | use std::{ |
| 12 | borrow::Borrow, |
| 13 | net::{Ipv4Addr, SocketAddr, SocketAddrV4}, |
| 14 | }; |
| 15 | use std::{io, mem::MaybeUninit}; |
| 16 | |
| 17 | /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using |
| 18 | /// the `Encoder` and `Decoder` traits to encode and decode frames. |
| 19 | /// |
| 20 | /// Raw UDP sockets work with datagrams, but higher-level code usually wants to |
| 21 | /// batch these into meaningful chunks, called "frames". This method layers |
| 22 | /// framing on top of this socket by using the `Encoder` and `Decoder` traits to |
| 23 | /// handle encoding and decoding of messages frames. Note that the incoming and |
| 24 | /// outgoing frame types may be distinct. |
| 25 | /// |
| 26 | /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; |
| 27 | /// grouping this into a single object is often useful for layering things which |
| 28 | /// require both read and write access to the underlying object. |
| 29 | /// |
| 30 | /// If you want to work more directly with the streams and sink, consider |
| 31 | /// calling [`split`] on the `UdpFramed` returned by this method, which will break |
| 32 | /// them into separate objects, allowing them to interact more easily. |
| 33 | /// |
| 34 | /// [`Stream`]: futures_core::Stream |
| 35 | /// [`Sink`]: futures_sink::Sink |
| 36 | /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split |
| 37 | #[must_use = "sinks do nothing unless polled" ] |
| 38 | #[derive(Debug)] |
| 39 | pub struct UdpFramed<C, T = UdpSocket> { |
| 40 | socket: T, |
| 41 | codec: C, |
| 42 | rd: BytesMut, |
| 43 | wr: BytesMut, |
| 44 | out_addr: SocketAddr, |
| 45 | flushed: bool, |
| 46 | is_readable: bool, |
| 47 | current_addr: Option<SocketAddr>, |
| 48 | } |
| 49 | |
/// Initial capacity of the read buffer, and the spare room (64 KiB) reserved in it
/// before datagrams are received.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Initial capacity (8 KiB) of the buffer outgoing frames are encoded into.
const INITIAL_WR_CAPACITY: usize = 1 << 13;
| 52 | |
| 53 | impl<C, T> Unpin for UdpFramed<C, T> {} |
| 54 | |
| 55 | impl<C, T> Stream for UdpFramed<C, T> |
| 56 | where |
| 57 | T: Borrow<UdpSocket>, |
| 58 | C: Decoder, |
| 59 | { |
| 60 | type Item = Result<(C::Item, SocketAddr), C::Error>; |
| 61 | |
| 62 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 63 | let pin = self.get_mut(); |
| 64 | |
| 65 | pin.rd.reserve(INITIAL_RD_CAPACITY); |
| 66 | |
| 67 | loop { |
| 68 | // Are there still bytes left in the read buffer to decode? |
| 69 | if pin.is_readable { |
| 70 | if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { |
| 71 | let current_addr = pin |
| 72 | .current_addr |
| 73 | .expect("will always be set before this line is called" ); |
| 74 | |
| 75 | return Poll::Ready(Some(Ok((frame, current_addr)))); |
| 76 | } |
| 77 | |
| 78 | // if this line has been reached then decode has returned `None`. |
| 79 | pin.is_readable = false; |
| 80 | pin.rd.clear(); |
| 81 | } |
| 82 | |
| 83 | // We're out of data. Try and fetch more data to decode |
| 84 | let addr = { |
| 85 | // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a |
| 86 | // transparent wrapper around `[MaybeUninit<u8>]`. |
| 87 | let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) }; |
| 88 | let mut read = ReadBuf::uninit(buf); |
| 89 | let ptr = read.filled().as_ptr(); |
| 90 | let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); |
| 91 | |
| 92 | assert_eq!(ptr, read.filled().as_ptr()); |
| 93 | let addr = res?; |
| 94 | |
| 95 | // Safety: This is guaranteed to be the number of initialized (and read) bytes due |
| 96 | // to the invariants provided by `ReadBuf::filled`. |
| 97 | unsafe { pin.rd.advance_mut(read.filled().len()) }; |
| 98 | |
| 99 | addr |
| 100 | }; |
| 101 | |
| 102 | pin.current_addr = Some(addr); |
| 103 | pin.is_readable = true; |
| 104 | } |
| 105 | } |
| 106 | } |
| 107 | |
| 108 | impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> |
| 109 | where |
| 110 | T: Borrow<UdpSocket>, |
| 111 | C: Encoder<I>, |
| 112 | { |
| 113 | type Error = C::Error; |
| 114 | |
| 115 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 116 | if !self.flushed { |
| 117 | match self.poll_flush(cx)? { |
| 118 | Poll::Ready(()) => {} |
| 119 | Poll::Pending => return Poll::Pending, |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | Poll::Ready(Ok(())) |
| 124 | } |
| 125 | |
| 126 | fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { |
| 127 | let (frame, out_addr) = item; |
| 128 | |
| 129 | let pin = self.get_mut(); |
| 130 | |
| 131 | pin.codec.encode(frame, &mut pin.wr)?; |
| 132 | pin.out_addr = out_addr; |
| 133 | pin.flushed = false; |
| 134 | |
| 135 | Ok(()) |
| 136 | } |
| 137 | |
| 138 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 139 | if self.flushed { |
| 140 | return Poll::Ready(Ok(())); |
| 141 | } |
| 142 | |
| 143 | let Self { |
| 144 | ref socket, |
| 145 | ref mut out_addr, |
| 146 | ref mut wr, |
| 147 | .. |
| 148 | } = *self; |
| 149 | |
| 150 | let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; |
| 151 | |
| 152 | let wrote_all = n == self.wr.len(); |
| 153 | self.wr.clear(); |
| 154 | self.flushed = true; |
| 155 | |
| 156 | let res = if wrote_all { |
| 157 | Ok(()) |
| 158 | } else { |
| 159 | Err(io::Error::new( |
| 160 | io::ErrorKind::Other, |
| 161 | "failed to write entire datagram to socket" , |
| 162 | ) |
| 163 | .into()) |
| 164 | }; |
| 165 | |
| 166 | Poll::Ready(res) |
| 167 | } |
| 168 | |
| 169 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 170 | ready!(self.poll_flush(cx))?; |
| 171 | Poll::Ready(Ok(())) |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | impl<C, T> UdpFramed<C, T> |
| 176 | where |
| 177 | T: Borrow<UdpSocket>, |
| 178 | { |
| 179 | /// Create a new `UdpFramed` backed by the given socket and codec. |
| 180 | /// |
| 181 | /// See struct level documentation for more details. |
| 182 | pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { |
| 183 | Self { |
| 184 | socket, |
| 185 | codec, |
| 186 | out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), |
| 187 | rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), |
| 188 | wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), |
| 189 | flushed: true, |
| 190 | is_readable: false, |
| 191 | current_addr: None, |
| 192 | } |
| 193 | } |
| 194 | |
| 195 | /// Returns a reference to the underlying I/O stream wrapped by `Framed`. |
| 196 | /// |
| 197 | /// # Note |
| 198 | /// |
| 199 | /// Care should be taken to not tamper with the underlying stream of data |
| 200 | /// coming in as it may corrupt the stream of frames otherwise being worked |
| 201 | /// with. |
| 202 | pub fn get_ref(&self) -> &T { |
| 203 | &self.socket |
| 204 | } |
| 205 | |
| 206 | /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. |
| 207 | /// |
| 208 | /// # Note |
| 209 | /// |
| 210 | /// Care should be taken to not tamper with the underlying stream of data |
| 211 | /// coming in as it may corrupt the stream of frames otherwise being worked |
| 212 | /// with. |
| 213 | pub fn get_mut(&mut self) -> &mut T { |
| 214 | &mut self.socket |
| 215 | } |
| 216 | |
| 217 | /// Returns a reference to the underlying codec wrapped by |
| 218 | /// `Framed`. |
| 219 | /// |
| 220 | /// Note that care should be taken to not tamper with the underlying codec |
| 221 | /// as it may corrupt the stream of frames otherwise being worked with. |
| 222 | pub fn codec(&self) -> &C { |
| 223 | &self.codec |
| 224 | } |
| 225 | |
| 226 | /// Returns a mutable reference to the underlying codec wrapped by |
| 227 | /// `UdpFramed`. |
| 228 | /// |
| 229 | /// Note that care should be taken to not tamper with the underlying codec |
| 230 | /// as it may corrupt the stream of frames otherwise being worked with. |
| 231 | pub fn codec_mut(&mut self) -> &mut C { |
| 232 | &mut self.codec |
| 233 | } |
| 234 | |
| 235 | /// Returns a reference to the read buffer. |
| 236 | pub fn read_buffer(&self) -> &BytesMut { |
| 237 | &self.rd |
| 238 | } |
| 239 | |
| 240 | /// Returns a mutable reference to the read buffer. |
| 241 | pub fn read_buffer_mut(&mut self) -> &mut BytesMut { |
| 242 | &mut self.rd |
| 243 | } |
| 244 | |
| 245 | /// Consumes the `Framed`, returning its underlying I/O stream. |
| 246 | pub fn into_inner(self) -> T { |
| 247 | self.socket |
| 248 | } |
| 249 | } |
| 250 | |