1use crate::frame::{self, Frame, Kind, Reason};
2use crate::frame::{
3 DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
4};
5use crate::proto::Error;
6
7use crate::hpack;
8
9use futures_core::Stream;
10
11use bytes::BytesMut;
12
13use std::io;
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::AsyncRead;
18use tokio_util::codec::FramedRead as InnerFramedRead;
19use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20
// 16 MiB (`16 << 20` bytes) "sane default" taken from golang http2.
//
// Used as the initial `max_header_list_size` of a freshly constructed
// `FramedRead`; it can be changed later via `set_max_header_list_size`.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20;
23
/// Reads length-delimited chunks from an inner reader and decodes them into
/// `Frame` values.
///
/// Besides the inner reader, it carries the state that has to survive between
/// frames: the HPACK decoder and any header block that is still waiting for
/// CONTINUATION frames.
#[derive(Debug)]
pub struct FramedRead<T> {
    // Source of raw frames; each item it yields is handed to `decode_frame`.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    // hpack decoder state
    hpack: hpack::Decoder,

    // Limit forwarded to `load_hpack` when decoding header blocks. Starts at
    // `DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE`.
    max_header_list_size: usize,

    // Header block that has been started but has not yet seen END_HEADERS.
    // `None` when no header block is in progress.
    partial: Option<Partial>,
}
35
/// Partially loaded headers frame
///
/// Stored by `decode_frame` when a HEADERS or PUSH_PROMISE frame arrives
/// without END_HEADERS, and picked up again when the following CONTINUATION
/// frame is decoded.
#[derive(Debug)]
struct Partial {
    /// The HEADERS / PUSH_PROMISE frame whose header block is still being
    /// assembled. It is only turned into a `Frame` once END_HEADERS is seen.
    frame: Continuable,

    /// Partial header payload
    ///
    /// This is the buffer passed to `load_hpack`. The payload of each
    /// CONTINUATION frame is appended to it (or replaces it when it is empty).
    buf: BytesMut,
}
45
/// The two frame kinds whose header block may be continued by CONTINUATION
/// frames.
#[derive(Debug)]
enum Continuable {
    /// A HEADERS frame awaiting the rest of its header block.
    Headers(frame::Headers),
    /// A PUSH_PROMISE frame awaiting the rest of its header block.
    PushPromise(frame::PushPromise),
}
51
52impl<T> FramedRead<T> {
53 pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
54 FramedRead {
55 inner,
56 hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
57 max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
58 partial: None,
59 }
60 }
61
62 pub fn get_ref(&self) -> &T {
63 self.inner.get_ref()
64 }
65
66 pub fn get_mut(&mut self) -> &mut T {
67 self.inner.get_mut()
68 }
69
70 /// Returns the current max frame size setting
71 #[cfg(feature = "unstable")]
72 #[inline]
73 pub fn max_frame_size(&self) -> usize {
74 self.inner.decoder().max_frame_length()
75 }
76
77 /// Updates the max frame size setting.
78 ///
79 /// Must be within 16,384 and 16,777,215.
80 #[inline]
81 pub fn set_max_frame_size(&mut self, val: usize) {
82 assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
83 self.inner.decoder_mut().set_max_frame_length(val)
84 }
85
86 /// Update the max header list size setting.
87 #[inline]
88 pub fn set_max_header_list_size(&mut self, val: usize) {
89 self.max_header_list_size = val;
90 }
91
92 /// Update the header table size setting.
93 #[inline]
94 pub fn set_header_table_size(&mut self, val: usize) {
95 self.hpack.queue_size_update(val);
96 }
97}
98
/// Decodes a frame.
///
/// This method is intentionally de-generified and outlined because it is very large.
///
/// # Arguments
///
/// * `hpack` - HPACK decoder shared by every header block on the connection.
/// * `max_header_list_size` - limit forwarded to `load_hpack`, and also used
///   to bound how much ignored CONTINUATION data may be buffered.
/// * `partial_inout` - header block in progress, if any. Set when a HEADERS or
///   PUSH_PROMISE frame arrives without END_HEADERS; taken (and possibly put
///   back) when a CONTINUATION frame is decoded.
/// * `bytes` - one whole frame: `frame::HEADER_LEN` bytes of frame header
///   followed by the payload.
///
/// # Returns
///
/// * `Ok(Some(frame))` - a complete frame was decoded.
/// * `Ok(None)` - nothing to hand to the caller yet: either the frame kind is
///   unknown (ignored) or the header block is still waiting for END_HEADERS.
///
/// # Errors
///
/// Returns a connection-level error (`Error::library_go_away`) for most
/// malformed input, and a stream-level error (`Error::library_reset`) for an
/// invalid dependency ID or a malformed header block.
fn decode_frame(
    hpack: &mut hpack::Decoder,
    max_header_list_size: usize,
    partial_inout: &mut Option<Partial>,
    mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
    let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
    let _e = span.enter();

    tracing::trace!("decoding frame from {}B", bytes.len());

    // Parse the head
    let head = frame::Head::parse(&bytes);

    // While a header block is pending, the only frame accepted next is a
    // CONTINUATION; anything else is a connection error.
    if partial_inout.is_some() && head.kind() != Kind::Continuation {
        proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
        return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
    }

    let kind = head.kind();

    tracing::trace!(frame.kind = ?kind);

    // Shared decoding path for HEADERS and PUSH_PROMISE. `$frame` names both
    // the `frame::` type to load and the `Continuable` variant to store when
    // the header block is incomplete.
    macro_rules! header_block {
        ($frame:ident, $head:ident, $bytes:ident) => ({
            // Drop the frame header
            // TODO: Change to drain: carllerche/bytes#130
            let _ = $bytes.split_to(frame::HEADER_LEN);

            // Parse the header frame w/o parsing the payload
            let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
                Ok(res) => res,
                Err(frame::Error::InvalidDependencyId) => {
                    proto_err!(stream: "invalid HEADERS dependency ID");
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed to load frame; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            let is_end_headers = frame.is_end_headers();

            // Load the HPACK encoded headers
            match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
                Ok(_) => {},
                // Running out of input is only acceptable when more of the
                // header block is still expected (no END_HEADERS yet).
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
                Err(frame::Error::MalformedMessage) => {
                    let id = $head.stream_id();
                    proto_err!(stream: "malformed header block; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                frame.into()
            } else {
                tracing::trace!("loaded partial header block");
                // Defer returning the frame
                *partial_inout = Some(Partial {
                    frame: Continuable::$frame(frame),
                    buf: payload,
                });

                return Ok(None);
            }
        });
    }

    // Most arms below hand the payload (everything after the frame header) to
    // the matching `frame::*::load` and map any failure to a connection-level
    // PROTOCOL_ERROR.
    let frame = match kind {
        Kind::Settings => {
            let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Ping => {
            let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load PING frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::WindowUpdate => {
            let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Data => {
            // The DATA payload is passed on as an owned, frozen buffer, so the
            // frame header is split off instead of being skipped by slicing.
            let _ = bytes.split_to(frame::HEADER_LEN);
            let res = frame::Data::load(head, bytes.freeze());

            // TODO: Should this always be connection level? Probably not...
            res.map_err(|e| {
                proto_err!(conn: "failed to load DATA frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Headers => header_block!(Headers, head, bytes),
        Kind::Reset => {
            let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load RESET frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::GoAway => {
            let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::PushPromise => header_block!(PushPromise, head, bytes),
        Kind::Priority => {
            if head.stream_id() == 0 {
                // Invalid stream identifier
                proto_err!(conn: "invalid stream ID 0");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
                Ok(frame) => frame.into(),
                Err(frame::Error::InvalidDependencyId) => {
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    let id = head.stream_id();
                    proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }
        }
        Kind::Continuation => {
            // Flag bit 0x4 is read here as END_HEADERS.
            let is_end_headers = (head.flag() & 0x4) == 0x4;

            // A CONTINUATION is only valid while a header block is pending.
            // The pending state is taken out here and put back further down
            // if the block is still incomplete.
            let mut partial = match partial_inout.take() {
                Some(partial) => partial,
                None => {
                    proto_err!(conn: "received unexpected CONTINUATION frame");
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            // The stream identifiers must match
            if partial.frame.stream_id() != head.stream_id() {
                proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            // Extend the buf
            if partial.buf.is_empty() {
                // Nothing buffered: reuse this frame's payload as the buffer
                // instead of copying it.
                partial.buf = bytes.split_off(frame::HEADER_LEN);
            } else {
                if partial.frame.is_over_size() {
                    // If there was left over bytes previously, they may be
                    // needed to continue decoding, even though we will
                    // be ignoring this frame. This is done to keep the HPACK
                    // decoder state up-to-date.
                    //
                    // Still, we need to be careful, because if a malicious
                    // attacker were to try to send a gigantic string, such
                    // that it fits over multiple header blocks, we could
                    // grow memory uncontrollably again, and that'd be a shame.
                    //
                    // Instead, we use a simple heuristic to determine if
                    // we should continue to ignore decoding, or to tell
                    // the attacker to go away.
                    //
                    // Note that `bytes.len()` here still includes the frame
                    // header.
                    if partial.buf.len() + bytes.len() > max_header_list_size {
                        proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
                        return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
                    }
                }
                partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
            }

            match partial
                .frame
                .load_hpack(&mut partial.buf, max_header_list_size, hpack)
            {
                Ok(_) => {}
                // As for the initial frame: needing more input is fine as long
                // as END_HEADERS has not been seen.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
                Err(frame::Error::MalformedMessage) => {
                    let id = head.stream_id();
                    proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                partial.frame.into()
            } else {
                // Still incomplete: keep the state for the next CONTINUATION.
                *partial_inout = Some(partial);
                return Ok(None);
            }
        }
        Kind::Unknown => {
            // Unknown frames are ignored
            return Ok(None);
        }
    };

    Ok(Some(frame))
}
334
335impl<T> Stream for FramedRead<T>
336where
337 T: AsyncRead + Unpin,
338{
339 type Item = Result<Frame, Error>;
340
341 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
342 let span = tracing::trace_span!("FramedRead::poll_next");
343 let _e = span.enter();
344 loop {
345 tracing::trace!("poll");
346 let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
347 Some(Ok(bytes)) => bytes,
348 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
349 None => return Poll::Ready(None),
350 };
351
352 tracing::trace!(read.bytes = bytes.len());
353 let Self {
354 ref mut hpack,
355 max_header_list_size,
356 ref mut partial,
357 ..
358 } = *self;
359 if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? {
360 tracing::debug!(?frame, "received");
361 return Poll::Ready(Some(Ok(frame)));
362 }
363 }
364 }
365}
366
367fn map_err(err: io::Error) -> Error {
368 if let io::ErrorKind::InvalidData = err.kind() {
369 if let Some(custom: &(dyn Error + Sync + Send)) = err.get_ref() {
370 if custom.is::<LengthDelimitedCodecError>() {
371 return Error::library_go_away(reason:Reason::FRAME_SIZE_ERROR);
372 }
373 }
374 }
375 err.into()
376}
377
378// ===== impl Continuable =====
379
380impl Continuable {
381 fn stream_id(&self) -> frame::StreamId {
382 match *self {
383 Continuable::Headers(ref h) => h.stream_id(),
384 Continuable::PushPromise(ref p) => p.stream_id(),
385 }
386 }
387
388 fn is_over_size(&self) -> bool {
389 match *self {
390 Continuable::Headers(ref h) => h.is_over_size(),
391 Continuable::PushPromise(ref p) => p.is_over_size(),
392 }
393 }
394
395 fn load_hpack(
396 &mut self,
397 src: &mut BytesMut,
398 max_header_list_size: usize,
399 decoder: &mut hpack::Decoder,
400 ) -> Result<(), frame::Error> {
401 match *self {
402 Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
403 Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
404 }
405 }
406}
407
408impl<T> From<Continuable> for Frame<T> {
409 fn from(cont: Continuable) -> Self {
410 match cont {
411 Continuable::Headers(mut headers: Headers) => {
412 headers.set_end_headers();
413 headers.into()
414 }
415 Continuable::PushPromise(mut push: PushPromise) => {
416 push.set_end_headers();
417 push.into()
418 }
419 }
420 }
421}
422