1 | use crate::codec::{Decoder, Encoder}; |
2 | |
3 | use futures_core::Stream; |
4 | use tokio::{io::ReadBuf, net::UdpSocket}; |
5 | |
6 | use bytes::{BufMut, BytesMut}; |
7 | use futures_core::ready; |
8 | use futures_sink::Sink; |
9 | use std::pin::Pin; |
10 | use std::task::{Context, Poll}; |
11 | use std::{ |
12 | borrow::Borrow, |
13 | net::{Ipv4Addr, SocketAddr, SocketAddrV4}, |
14 | }; |
15 | use std::{io, mem::MaybeUninit}; |
16 | |
/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
/// work with meaningful chunks, called "frames". `UdpFramed` layers framing on
/// top of a socket. Each received datagram is handed to the `Decoder`, which may
/// yield one or more frames from it. Each frame given to the sink is encoded with
/// the `Encoder` and sent as a single datagram. Note that the incoming and
/// outgoing frame types may be distinct.
///
/// Every item produced by the stream, and every item accepted by the sink, is
/// paired with a [`SocketAddr`]. For incoming frames it is the sender's address;
/// for outgoing frames it is the destination address.
///
/// `UdpFramed` is a *single* object that is both [`Stream`] and [`Sink`];
/// grouping this into a single object is often useful for layering things which
/// require both read and write access to the underlying object.
///
/// The socket type `T` defaults to [`UdpSocket`]. The impls only require
/// `T: Borrow<UdpSocket>`, so borrowed or shared handles such as `&UdpSocket` or
/// `Arc<UdpSocket>` can be used as well.
///
/// If you want to work more directly with the streams and sink, consider
/// calling [`split`] on the `UdpFramed`, which will break it into separate
/// objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct UdpFramed<C, T = UdpSocket> {
    /// The socket that datagrams are received from and sent to.
    socket: T,
    /// Codec used to decode received datagrams and encode outgoing frames.
    codec: C,
    /// Read buffer holding the most recently received datagram, or whatever part of
    /// it the decoder has not consumed yet.
    rd: BytesMut,
    /// Write buffer holding the encoded frame that has not been sent yet.
    wr: BytesMut,
    /// Destination of the datagram currently held in `wr`, set by `start_send`.
    out_addr: SocketAddr,
    /// `true` when `wr` holds nothing that still has to be sent.
    flushed: bool,
    /// `true` while `rd` holds a received datagram that may still yield frames.
    is_readable: bool,
    /// Address the datagram in `rd` was received from; `None` until the first receive.
    current_addr: Option<SocketAddr>,
}
49 | |
/// Spare capacity (64 KiB) the read buffer starts with and is topped up to while
/// polling for datagrams. This is enough room for any non-jumbogram UDP payload.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Starting capacity (8 KiB) of the write buffer; `BytesMut` grows past it when
/// an encoded frame needs more room.
const INITIAL_WR_CAPACITY: usize = 1 << 13;
52 | |
// `UdpFramed` is `Unpin` for every `C` and `T`. Nothing in this type treats a field as
// pinned: the trait impls reach their fields through `Pin::get_mut` or plain `DerefMut`.
// That access itself requires `Self: Unpin`, which is what this impl provides.
impl<C, T> Unpin for UdpFramed<C, T> {}
54 | |
55 | impl<C, T> Stream for UdpFramed<C, T> |
56 | where |
57 | T: Borrow<UdpSocket>, |
58 | C: Decoder, |
59 | { |
60 | type Item = Result<(C::Item, SocketAddr), C::Error>; |
61 | |
62 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
63 | let pin = self.get_mut(); |
64 | |
65 | pin.rd.reserve(INITIAL_RD_CAPACITY); |
66 | |
67 | loop { |
68 | // Are there still bytes left in the read buffer to decode? |
69 | if pin.is_readable { |
70 | if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { |
71 | let current_addr = pin |
72 | .current_addr |
73 | .expect("will always be set before this line is called" ); |
74 | |
75 | return Poll::Ready(Some(Ok((frame, current_addr)))); |
76 | } |
77 | |
78 | // if this line has been reached then decode has returned `None`. |
79 | pin.is_readable = false; |
80 | pin.rd.clear(); |
81 | } |
82 | |
83 | // We're out of data. Try and fetch more data to decode |
84 | let addr = { |
85 | // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a |
86 | // transparent wrapper around `[MaybeUninit<u8>]`. |
87 | let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) }; |
88 | let mut read = ReadBuf::uninit(buf); |
89 | let ptr = read.filled().as_ptr(); |
90 | let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); |
91 | |
92 | assert_eq!(ptr, read.filled().as_ptr()); |
93 | let addr = res?; |
94 | |
95 | // Safety: This is guaranteed to be the number of initialized (and read) bytes due |
96 | // to the invariants provided by `ReadBuf::filled`. |
97 | unsafe { pin.rd.advance_mut(read.filled().len()) }; |
98 | |
99 | addr |
100 | }; |
101 | |
102 | pin.current_addr = Some(addr); |
103 | pin.is_readable = true; |
104 | } |
105 | } |
106 | } |
107 | |
108 | impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> |
109 | where |
110 | T: Borrow<UdpSocket>, |
111 | C: Encoder<I>, |
112 | { |
113 | type Error = C::Error; |
114 | |
115 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
116 | if !self.flushed { |
117 | match self.poll_flush(cx)? { |
118 | Poll::Ready(()) => {} |
119 | Poll::Pending => return Poll::Pending, |
120 | } |
121 | } |
122 | |
123 | Poll::Ready(Ok(())) |
124 | } |
125 | |
126 | fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { |
127 | let (frame, out_addr) = item; |
128 | |
129 | let pin = self.get_mut(); |
130 | |
131 | pin.codec.encode(frame, &mut pin.wr)?; |
132 | pin.out_addr = out_addr; |
133 | pin.flushed = false; |
134 | |
135 | Ok(()) |
136 | } |
137 | |
138 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
139 | if self.flushed { |
140 | return Poll::Ready(Ok(())); |
141 | } |
142 | |
143 | let Self { |
144 | ref socket, |
145 | ref mut out_addr, |
146 | ref mut wr, |
147 | .. |
148 | } = *self; |
149 | |
150 | let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; |
151 | |
152 | let wrote_all = n == self.wr.len(); |
153 | self.wr.clear(); |
154 | self.flushed = true; |
155 | |
156 | let res = if wrote_all { |
157 | Ok(()) |
158 | } else { |
159 | Err(io::Error::new( |
160 | io::ErrorKind::Other, |
161 | "failed to write entire datagram to socket" , |
162 | ) |
163 | .into()) |
164 | }; |
165 | |
166 | Poll::Ready(res) |
167 | } |
168 | |
169 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
170 | ready!(self.poll_flush(cx))?; |
171 | Poll::Ready(Ok(())) |
172 | } |
173 | } |
174 | |
175 | impl<C, T> UdpFramed<C, T> |
176 | where |
177 | T: Borrow<UdpSocket>, |
178 | { |
179 | /// Create a new `UdpFramed` backed by the given socket and codec. |
180 | /// |
181 | /// See struct level documentation for more details. |
182 | pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { |
183 | Self { |
184 | socket, |
185 | codec, |
186 | out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), |
187 | rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), |
188 | wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), |
189 | flushed: true, |
190 | is_readable: false, |
191 | current_addr: None, |
192 | } |
193 | } |
194 | |
195 | /// Returns a reference to the underlying I/O stream wrapped by `Framed`. |
196 | /// |
197 | /// # Note |
198 | /// |
199 | /// Care should be taken to not tamper with the underlying stream of data |
200 | /// coming in as it may corrupt the stream of frames otherwise being worked |
201 | /// with. |
202 | pub fn get_ref(&self) -> &T { |
203 | &self.socket |
204 | } |
205 | |
206 | /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. |
207 | /// |
208 | /// # Note |
209 | /// |
210 | /// Care should be taken to not tamper with the underlying stream of data |
211 | /// coming in as it may corrupt the stream of frames otherwise being worked |
212 | /// with. |
213 | pub fn get_mut(&mut self) -> &mut T { |
214 | &mut self.socket |
215 | } |
216 | |
217 | /// Returns a reference to the underlying codec wrapped by |
218 | /// `Framed`. |
219 | /// |
220 | /// Note that care should be taken to not tamper with the underlying codec |
221 | /// as it may corrupt the stream of frames otherwise being worked with. |
222 | pub fn codec(&self) -> &C { |
223 | &self.codec |
224 | } |
225 | |
226 | /// Returns a mutable reference to the underlying codec wrapped by |
227 | /// `UdpFramed`. |
228 | /// |
229 | /// Note that care should be taken to not tamper with the underlying codec |
230 | /// as it may corrupt the stream of frames otherwise being worked with. |
231 | pub fn codec_mut(&mut self) -> &mut C { |
232 | &mut self.codec |
233 | } |
234 | |
235 | /// Returns a reference to the read buffer. |
236 | pub fn read_buffer(&self) -> &BytesMut { |
237 | &self.rd |
238 | } |
239 | |
240 | /// Returns a mutable reference to the read buffer. |
241 | pub fn read_buffer_mut(&mut self) -> &mut BytesMut { |
242 | &mut self.rd |
243 | } |
244 | |
245 | /// Consumes the `Framed`, returning its underlying I/O stream. |
246 | pub fn into_inner(self) -> T { |
247 | self.socket |
248 | } |
249 | } |
250 | |