1mod error;
2mod framed_read;
3mod framed_write;
4
5pub use self::error::{SendError, UserError};
6
7use self::framed_read::FramedRead;
8use self::framed_write::FramedWrite;
9
10use crate::frame::{self, Data, Frame};
11use crate::proto::Error;
12
13use bytes::Buf;
14use futures_core::Stream;
15use futures_sink::Sink;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::length_delimited;
20
21use std::io;
22
23#[derive(Debug)]
24pub struct Codec<T, B> {
25 inner: FramedRead<FramedWrite<T, B>>,
26}
27
28impl<T, B> Codec<T, B>
29where
30 T: AsyncRead + AsyncWrite + Unpin,
31 B: Buf,
32{
33 /// Returns a new `Codec` with the default max frame size
34 #[inline]
35 pub fn new(io: T) -> Self {
36 Self::with_max_recv_frame_size(io, frame::DEFAULT_MAX_FRAME_SIZE as usize)
37 }
38
39 /// Returns a new `Codec` with the given maximum frame size
40 pub fn with_max_recv_frame_size(io: T, max_frame_size: usize) -> Self {
41 // Wrap with writer
42 let framed_write = FramedWrite::new(io);
43
44 // Delimit the frames
45 let delimited = length_delimited::Builder::new()
46 .big_endian()
47 .length_field_length(3)
48 .length_adjustment(9)
49 .num_skip(0) // Don't skip the header
50 .new_read(framed_write);
51
52 let mut inner = FramedRead::new(delimited);
53
54 // Use FramedRead's method since it checks the value is within range.
55 inner.set_max_frame_size(max_frame_size);
56
57 Codec { inner }
58 }
59}
60
61impl<T, B> Codec<T, B> {
62 /// Updates the max received frame size.
63 ///
64 /// The change takes effect the next time a frame is decoded. In other
65 /// words, if a frame is currently in process of being decoded with a frame
66 /// size greater than `val` but less than the max frame size in effect
67 /// before calling this function, then the frame will be allowed.
68 #[inline]
69 pub fn set_max_recv_frame_size(&mut self, val: usize) {
70 self.inner.set_max_frame_size(val)
71 }
72
73 /// Returns the current max received frame size setting.
74 ///
75 /// This is the largest size this codec will accept from the wire. Larger
76 /// frames will be rejected.
77 #[cfg(feature = "unstable")]
78 #[inline]
79 pub fn max_recv_frame_size(&self) -> usize {
80 self.inner.max_frame_size()
81 }
82
83 /// Returns the max frame size that can be sent to the peer.
84 pub fn max_send_frame_size(&self) -> usize {
85 self.inner.get_ref().max_frame_size()
86 }
87
88 /// Set the peer's max frame size.
89 pub fn set_max_send_frame_size(&mut self, val: usize) {
90 self.framed_write().set_max_frame_size(val)
91 }
92
93 /// Set the peer's header table size size.
94 pub fn set_send_header_table_size(&mut self, val: usize) {
95 self.framed_write().set_header_table_size(val)
96 }
97
98 /// Set the decoder header table size size.
99 pub fn set_recv_header_table_size(&mut self, val: usize) {
100 self.inner.set_header_table_size(val)
101 }
102
103 /// Set the max header list size that can be received.
104 pub fn set_max_recv_header_list_size(&mut self, val: usize) {
105 self.inner.set_max_header_list_size(val);
106 }
107
108 /// Get a reference to the inner stream.
109 #[cfg(feature = "unstable")]
110 pub fn get_ref(&self) -> &T {
111 self.inner.get_ref().get_ref()
112 }
113
114 /// Get a mutable reference to the inner stream.
115 pub fn get_mut(&mut self) -> &mut T {
116 self.inner.get_mut().get_mut()
117 }
118
119 /// Takes the data payload value that was fully written to the socket
120 pub(crate) fn take_last_data_frame(&mut self) -> Option<Data<B>> {
121 self.framed_write().take_last_data_frame()
122 }
123
124 fn framed_write(&mut self) -> &mut FramedWrite<T, B> {
125 self.inner.get_mut()
126 }
127}
128
129impl<T, B> Codec<T, B>
130where
131 T: AsyncWrite + Unpin,
132 B: Buf,
133{
134 /// Returns `Ready` when the codec can buffer a frame
135 pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
136 self.framed_write().poll_ready(cx)
137 }
138
139 /// Buffer a frame.
140 ///
141 /// `poll_ready` must be called first to ensure that a frame may be
142 /// accepted.
143 ///
144 /// TODO: Rename this to avoid conflicts with Sink::buffer
145 pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
146 self.framed_write().buffer(item)
147 }
148
149 /// Flush buffered data to the wire
150 pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
151 self.framed_write().flush(cx)
152 }
153
154 /// Shutdown the send half
155 pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
156 self.framed_write().shutdown(cx)
157 }
158}
159
160impl<T, B> Stream for Codec<T, B>
161where
162 T: AsyncRead + Unpin,
163{
164 type Item = Result<Frame, Error>;
165
166 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167 Pin::new(&mut self.inner).poll_next(cx)
168 }
169}
170
171impl<T, B> Sink<Frame<B>> for Codec<T, B>
172where
173 T: AsyncWrite + Unpin,
174 B: Buf,
175{
176 type Error = SendError;
177
178 fn start_send(mut self: Pin<&mut Self>, item: Frame<B>) -> Result<(), Self::Error> {
179 Codec::buffer(&mut self, item)?;
180 Ok(())
181 }
182 /// Returns `Ready` when the codec can buffer a frame
183 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 self.framed_write().poll_ready(cx).map_err(Into::into)
185 }
186
187 /// Flush buffered data to the wire
188 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.framed_write().flush(cx).map_err(Into::into)
190 }
191
192 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193 ready!(self.shutdown(cx))?;
194 Poll::Ready(Ok(()))
195 }
196}
197
198// TODO: remove (or improve) this
199impl<T> From<T> for Codec<T, bytes::Bytes>
200where
201 T: AsyncRead + AsyncWrite + Unpin,
202{
203 fn from(src: T) -> Self {
204 Self::new(io:src)
205 }
206}
207