1use crate::codec::{Decoder, Encoder};
2
3use futures_core::Stream;
4use tokio::{io::ReadBuf, net::UdpSocket};
5
6use bytes::{BufMut, BytesMut};
7use futures_core::ready;
8use futures_sink::Sink;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{
12 borrow::Borrow,
13 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
14};
15use std::{io, mem::MaybeUninit};
16
17/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
18/// the `Encoder` and `Decoder` traits to encode and decode frames.
19///
20/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
21/// batch these into meaningful chunks, called "frames". This method layers
22/// framing on top of this socket by using the `Encoder` and `Decoder` traits to
23/// handle encoding and decoding of messages frames. Note that the incoming and
24/// outgoing frame types may be distinct.
25///
26/// This function returns a *single* object that is both [`Stream`] and [`Sink`];
27/// grouping this into a single object is often useful for layering things which
28/// require both read and write access to the underlying object.
29///
30/// If you want to work more directly with the streams and sink, consider
31/// calling [`split`] on the `UdpFramed` returned by this method, which will break
32/// them into separate objects, allowing them to interact more easily.
33///
34/// [`Stream`]: futures_core::Stream
35/// [`Sink`]: futures_sink::Sink
36/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
37#[must_use = "sinks do nothing unless polled"]
38#[derive(Debug)]
39pub struct UdpFramed<C, T = UdpSocket> {
40 socket: T,
41 codec: C,
42 rd: BytesMut,
43 wr: BytesMut,
44 out_addr: SocketAddr,
45 flushed: bool,
46 is_readable: bool,
47 current_addr: Option<SocketAddr>,
48}
49
/// Read-buffer capacity `UdpFramed` allocates and reserves for receiving (64 KiB),
/// enough for the largest non-jumbogram UDP payload.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Capacity the write buffer is allocated with (8 KiB).
const INITIAL_WR_CAPACITY: usize = 1 << 13;
52
// `UdpFramed` never projects a pin onto any of its fields, so it is sound to make
// it `Unpin` regardless of `C` and `T`. The `Stream` and `Sink` impls rely on this
// to obtain `&mut Self` via `Pin::get_mut`.
impl<C, T> Unpin for UdpFramed<C, T> {}
54
55impl<C, T> Stream for UdpFramed<C, T>
56where
57 T: Borrow<UdpSocket>,
58 C: Decoder,
59{
60 type Item = Result<(C::Item, SocketAddr), C::Error>;
61
62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63 let pin = self.get_mut();
64
65 pin.rd.reserve(INITIAL_RD_CAPACITY);
66
67 loop {
68 // Are there still bytes left in the read buffer to decode?
69 if pin.is_readable {
70 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
71 let current_addr = pin
72 .current_addr
73 .expect("will always be set before this line is called");
74
75 return Poll::Ready(Some(Ok((frame, current_addr))));
76 }
77
78 // if this line has been reached then decode has returned `None`.
79 pin.is_readable = false;
80 pin.rd.clear();
81 }
82
83 // We're out of data. Try and fetch more data to decode
84 let addr = {
85 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
86 // transparent wrapper around `[MaybeUninit<u8>]`.
87 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
88 let mut read = ReadBuf::uninit(buf);
89 let ptr = read.filled().as_ptr();
90 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
91
92 assert_eq!(ptr, read.filled().as_ptr());
93 let addr = res?;
94
95 // Safety: This is guaranteed to be the number of initialized (and read) bytes due
96 // to the invariants provided by `ReadBuf::filled`.
97 unsafe { pin.rd.advance_mut(read.filled().len()) };
98
99 addr
100 };
101
102 pin.current_addr = Some(addr);
103 pin.is_readable = true;
104 }
105 }
106}
107
108impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
109where
110 T: Borrow<UdpSocket>,
111 C: Encoder<I>,
112{
113 type Error = C::Error;
114
115 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 if !self.flushed {
117 match self.poll_flush(cx)? {
118 Poll::Ready(()) => {}
119 Poll::Pending => return Poll::Pending,
120 }
121 }
122
123 Poll::Ready(Ok(()))
124 }
125
126 fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
127 let (frame, out_addr) = item;
128
129 let pin = self.get_mut();
130
131 pin.codec.encode(frame, &mut pin.wr)?;
132 pin.out_addr = out_addr;
133 pin.flushed = false;
134
135 Ok(())
136 }
137
138 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139 if self.flushed {
140 return Poll::Ready(Ok(()));
141 }
142
143 let Self {
144 ref socket,
145 ref mut out_addr,
146 ref mut wr,
147 ..
148 } = *self;
149
150 let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
151
152 let wrote_all = n == self.wr.len();
153 self.wr.clear();
154 self.flushed = true;
155
156 let res = if wrote_all {
157 Ok(())
158 } else {
159 Err(io::Error::new(
160 io::ErrorKind::Other,
161 "failed to write entire datagram to socket",
162 )
163 .into())
164 };
165
166 Poll::Ready(res)
167 }
168
169 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170 ready!(self.poll_flush(cx))?;
171 Poll::Ready(Ok(()))
172 }
173}
174
175impl<C, T> UdpFramed<C, T>
176where
177 T: Borrow<UdpSocket>,
178{
179 /// Create a new `UdpFramed` backed by the given socket and codec.
180 ///
181 /// See struct level documentation for more details.
182 pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
183 Self {
184 socket,
185 codec,
186 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
187 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
188 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
189 flushed: true,
190 is_readable: false,
191 current_addr: None,
192 }
193 }
194
195 /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
196 ///
197 /// # Note
198 ///
199 /// Care should be taken to not tamper with the underlying stream of data
200 /// coming in as it may corrupt the stream of frames otherwise being worked
201 /// with.
202 pub fn get_ref(&self) -> &T {
203 &self.socket
204 }
205
206 /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
207 ///
208 /// # Note
209 ///
210 /// Care should be taken to not tamper with the underlying stream of data
211 /// coming in as it may corrupt the stream of frames otherwise being worked
212 /// with.
213 pub fn get_mut(&mut self) -> &mut T {
214 &mut self.socket
215 }
216
217 /// Returns a reference to the underlying codec wrapped by
218 /// `Framed`.
219 ///
220 /// Note that care should be taken to not tamper with the underlying codec
221 /// as it may corrupt the stream of frames otherwise being worked with.
222 pub fn codec(&self) -> &C {
223 &self.codec
224 }
225
226 /// Returns a mutable reference to the underlying codec wrapped by
227 /// `UdpFramed`.
228 ///
229 /// Note that care should be taken to not tamper with the underlying codec
230 /// as it may corrupt the stream of frames otherwise being worked with.
231 pub fn codec_mut(&mut self) -> &mut C {
232 &mut self.codec
233 }
234
235 /// Returns a reference to the read buffer.
236 pub fn read_buffer(&self) -> &BytesMut {
237 &self.rd
238 }
239
240 /// Returns a mutable reference to the read buffer.
241 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
242 &mut self.rd
243 }
244
245 /// Consumes the `Framed`, returning its underlying I/O stream.
246 pub fn into_inner(self) -> T {
247 self.socket
248 }
249}
250