1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13
14type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15/// Header frame
16///
17/// This could be either a request or a response.
18#[derive(Eq, PartialEq)]
19pub struct Headers {
20 /// The ID of the stream with which this frame is associated.
21 stream_id: StreamId,
22
23 /// The stream dependency information, if any.
24 stream_dep: Option<StreamDependency>,
25
26 /// The header block fragment
27 header_block: HeaderBlock,
28
29 /// The associated flags
30 flags: HeadersFlag,
31}
32
/// Flag bits carried in the head of a HEADERS frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
35
36#[derive(Eq, PartialEq)]
37pub struct PushPromise {
38 /// The ID of the stream with which this frame is associated.
39 stream_id: StreamId,
40
41 /// The ID of the stream being reserved by this PushPromise.
42 promised_id: StreamId,
43
44 /// The header block fragment
45 header_block: HeaderBlock,
46
47 /// The associated flags
48 flags: PushPromiseFlag,
49}
50
/// Flag bits carried in the head of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
53
54#[derive(Debug)]
55pub struct Continuation {
56 /// Stream ID of continuation frame
57 stream_id: StreamId,
58
59 header_block: EncodingHeaderBlock,
60}
61
62// TODO: These fields shouldn't be `pub`
63#[derive(Debug, Default, Eq, PartialEq)]
64pub struct Pseudo {
65 // Request
66 pub method: Option<Method>,
67 pub scheme: Option<BytesStr>,
68 pub authority: Option<BytesStr>,
69 pub path: Option<BytesStr>,
70 pub protocol: Option<Protocol>,
71
72 // Response
73 pub status: Option<StatusCode>,
74}
75
76#[derive(Debug)]
77pub struct Iter {
78 /// Pseudo headers
79 pseudo: Option<Pseudo>,
80
81 /// Header fields
82 fields: header::IntoIter<HeaderValue>,
83}
84
85#[derive(Debug, PartialEq, Eq)]
86struct HeaderBlock {
87 /// The decoded header fields
88 fields: HeaderMap,
89
90 /// Set to true if decoding went over the max header list size.
91 is_over_size: bool,
92
93 /// Pseudo headers, these are broken out as they must be sent as part of the
94 /// headers frame.
95 pseudo: Pseudo,
96}
97
98#[derive(Debug)]
99struct EncodingHeaderBlock {
100 hpack: Bytes,
101}
102
/// `END_STREAM` flag bit (HEADERS).
const END_STREAM: u8 = 0x1;
/// `END_HEADERS` flag bit (HEADERS, PUSH_PROMISE, CONTINUATION).
const END_HEADERS: u8 = 0x4;
/// `PADDED` flag bit (HEADERS, PUSH_PROMISE).
const PADDED: u8 = 0x8;
/// `PRIORITY` flag bit (HEADERS).
const PRIORITY: u8 = 0x20;
/// Every flag bit defined for a HEADERS frame.
const ALL: u8 = PRIORITY | PADDED | END_HEADERS | END_STREAM;
108
109// ===== impl Headers =====
110
111impl Headers {
112 /// Create a new HEADERS frame
113 pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
114 Headers {
115 stream_id,
116 stream_dep: None,
117 header_block: HeaderBlock {
118 fields,
119 is_over_size: false,
120 pseudo,
121 },
122 flags: HeadersFlag::default(),
123 }
124 }
125
126 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
127 let mut flags = HeadersFlag::default();
128 flags.set_end_stream();
129
130 Headers {
131 stream_id,
132 stream_dep: None,
133 header_block: HeaderBlock {
134 fields,
135 is_over_size: false,
136 pseudo: Pseudo::default(),
137 },
138 flags,
139 }
140 }
141
142 /// Loads the header frame but doesn't actually do HPACK decoding.
143 ///
144 /// HPACK decoding is done in the `load_hpack` step.
145 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
146 let flags = HeadersFlag(head.flag());
147 let mut pad = 0;
148
149 tracing::trace!("loading headers; flags={:?}", flags);
150
151 if head.stream_id().is_zero() {
152 return Err(Error::InvalidStreamId);
153 }
154
155 // Read the padding length
156 if flags.is_padded() {
157 if src.is_empty() {
158 return Err(Error::MalformedMessage);
159 }
160 pad = src[0] as usize;
161
162 // Drop the padding
163 let _ = src.split_to(1);
164 }
165
166 // Read the stream dependency
167 let stream_dep = if flags.is_priority() {
168 if src.len() < 5 {
169 return Err(Error::MalformedMessage);
170 }
171 let stream_dep = StreamDependency::load(&src[..5])?;
172
173 if stream_dep.dependency_id() == head.stream_id() {
174 return Err(Error::InvalidDependencyId);
175 }
176
177 // Drop the next 5 bytes
178 let _ = src.split_to(5);
179
180 Some(stream_dep)
181 } else {
182 None
183 };
184
185 if pad > 0 {
186 if pad > src.len() {
187 return Err(Error::TooMuchPadding);
188 }
189
190 let len = src.len() - pad;
191 src.truncate(len);
192 }
193
194 let headers = Headers {
195 stream_id: head.stream_id(),
196 stream_dep,
197 header_block: HeaderBlock {
198 fields: HeaderMap::new(),
199 is_over_size: false,
200 pseudo: Pseudo::default(),
201 },
202 flags,
203 };
204
205 Ok((headers, src))
206 }
207
208 pub fn load_hpack(
209 &mut self,
210 src: &mut BytesMut,
211 max_header_list_size: usize,
212 decoder: &mut hpack::Decoder,
213 ) -> Result<(), Error> {
214 self.header_block.load(src, max_header_list_size, decoder)
215 }
216
217 pub fn stream_id(&self) -> StreamId {
218 self.stream_id
219 }
220
221 pub fn is_end_headers(&self) -> bool {
222 self.flags.is_end_headers()
223 }
224
225 pub fn set_end_headers(&mut self) {
226 self.flags.set_end_headers();
227 }
228
229 pub fn is_end_stream(&self) -> bool {
230 self.flags.is_end_stream()
231 }
232
233 pub fn set_end_stream(&mut self) {
234 self.flags.set_end_stream()
235 }
236
237 pub fn is_over_size(&self) -> bool {
238 self.header_block.is_over_size
239 }
240
241 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
242 (self.header_block.pseudo, self.header_block.fields)
243 }
244
245 #[cfg(feature = "unstable")]
246 pub fn pseudo_mut(&mut self) -> &mut Pseudo {
247 &mut self.header_block.pseudo
248 }
249
250 /// Whether it has status 1xx
251 pub(crate) fn is_informational(&self) -> bool {
252 self.header_block.pseudo.is_informational()
253 }
254
255 pub fn fields(&self) -> &HeaderMap {
256 &self.header_block.fields
257 }
258
259 pub fn into_fields(self) -> HeaderMap {
260 self.header_block.fields
261 }
262
263 pub fn encode(
264 self,
265 encoder: &mut hpack::Encoder,
266 dst: &mut EncodeBuf<'_>,
267 ) -> Option<Continuation> {
268 // At this point, the `is_end_headers` flag should always be set
269 debug_assert!(self.flags.is_end_headers());
270
271 // Get the HEADERS frame head
272 let head = self.head();
273
274 self.header_block
275 .into_encoding(encoder)
276 .encode(&head, dst, |_| {})
277 }
278
279 fn head(&self) -> Head {
280 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
281 }
282}
283
284impl<T> From<Headers> for Frame<T> {
285 fn from(src: Headers) -> Self {
286 Frame::Headers(src)
287 }
288}
289
290impl fmt::Debug for Headers {
291 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292 let mut builder: DebugStruct<'_, '_> = f.debug_struct(name:"Headers");
293 builder
294 .field("stream_id", &self.stream_id)
295 .field(name:"flags", &self.flags);
296
297 if let Some(ref protocol: &Protocol) = self.header_block.pseudo.protocol {
298 builder.field(name:"protocol", value:protocol);
299 }
300
301 if let Some(ref dep: &StreamDependency) = self.stream_dep {
302 builder.field(name:"stream_dep", value:dep);
303 }
304
305 // `fields` and `pseudo` purposefully not included
306 builder.finish()
307 }
308}
309
310// ===== util =====
311
/// Error returned by [`parse_u64`] for input that is not a short run of ASCII digits.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer made only of ASCII digits.
///
/// Inputs longer than 19 digits are rejected up front. Every 19-digit value
/// fits in a `u64`, so the accumulation below cannot overflow.
///
/// NOTE: an empty slice parses as `Ok(0)`; callers validating header values
/// such as `content-length` rely on this as-is.
///
/// # Errors
///
/// Returns `ParseU64Error` if `src` is longer than 19 bytes or contains any
/// byte outside `b'0'..=b'9'` (signs, whitespace and separators included).
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    if src.len() > 19 {
        // At danger for overflow...
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            Ok(acc * 10 + u64::from(d - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
334
335// ===== impl PushPromise =====
336
337#[derive(Debug)]
338pub enum PushPromiseHeaderError {
339 InvalidContentLength(Result<u64, ParseU64Error>),
340 NotSafeAndCacheable,
341}
342
343impl PushPromise {
344 pub fn new(
345 stream_id: StreamId,
346 promised_id: StreamId,
347 pseudo: Pseudo,
348 fields: HeaderMap,
349 ) -> Self {
350 PushPromise {
351 flags: PushPromiseFlag::default(),
352 header_block: HeaderBlock {
353 fields,
354 is_over_size: false,
355 pseudo,
356 },
357 promised_id,
358 stream_id,
359 }
360 }
361
362 pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
363 use PushPromiseHeaderError::*;
364 // The spec has some requirements for promised request headers
365 // [https://httpwg.org/specs/rfc7540.html#PushRequests]
366
367 // A promised request "that indicates the presence of a request body
368 // MUST reset the promised stream with a stream error"
369 if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
370 let parsed_length = parse_u64(content_length.as_bytes());
371 if parsed_length != Ok(0) {
372 return Err(InvalidContentLength(parsed_length));
373 }
374 }
375 // "The server MUST include a method in the :method pseudo-header field
376 // that is safe and cacheable"
377 if !Self::safe_and_cacheable(req.method()) {
378 return Err(NotSafeAndCacheable);
379 }
380
381 Ok(())
382 }
383
384 fn safe_and_cacheable(method: &Method) -> bool {
385 // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
386 // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
387 method == Method::GET || method == Method::HEAD
388 }
389
390 pub fn fields(&self) -> &HeaderMap {
391 &self.header_block.fields
392 }
393
394 #[cfg(feature = "unstable")]
395 pub fn into_fields(self) -> HeaderMap {
396 self.header_block.fields
397 }
398
399 /// Loads the push promise frame but doesn't actually do HPACK decoding.
400 ///
401 /// HPACK decoding is done in the `load_hpack` step.
402 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
403 let flags = PushPromiseFlag(head.flag());
404 let mut pad = 0;
405
406 if head.stream_id().is_zero() {
407 return Err(Error::InvalidStreamId);
408 }
409
410 // Read the padding length
411 if flags.is_padded() {
412 if src.is_empty() {
413 return Err(Error::MalformedMessage);
414 }
415
416 // TODO: Ensure payload is sized correctly
417 pad = src[0] as usize;
418
419 // Drop the padding
420 let _ = src.split_to(1);
421 }
422
423 if src.len() < 5 {
424 return Err(Error::MalformedMessage);
425 }
426
427 let (promised_id, _) = StreamId::parse(&src[..4]);
428 // Drop promised_id bytes
429 let _ = src.split_to(4);
430
431 if pad > 0 {
432 if pad > src.len() {
433 return Err(Error::TooMuchPadding);
434 }
435
436 let len = src.len() - pad;
437 src.truncate(len);
438 }
439
440 let frame = PushPromise {
441 flags,
442 header_block: HeaderBlock {
443 fields: HeaderMap::new(),
444 is_over_size: false,
445 pseudo: Pseudo::default(),
446 },
447 promised_id,
448 stream_id: head.stream_id(),
449 };
450 Ok((frame, src))
451 }
452
453 pub fn load_hpack(
454 &mut self,
455 src: &mut BytesMut,
456 max_header_list_size: usize,
457 decoder: &mut hpack::Decoder,
458 ) -> Result<(), Error> {
459 self.header_block.load(src, max_header_list_size, decoder)
460 }
461
462 pub fn stream_id(&self) -> StreamId {
463 self.stream_id
464 }
465
466 pub fn promised_id(&self) -> StreamId {
467 self.promised_id
468 }
469
470 pub fn is_end_headers(&self) -> bool {
471 self.flags.is_end_headers()
472 }
473
474 pub fn set_end_headers(&mut self) {
475 self.flags.set_end_headers();
476 }
477
478 pub fn is_over_size(&self) -> bool {
479 self.header_block.is_over_size
480 }
481
482 pub fn encode(
483 self,
484 encoder: &mut hpack::Encoder,
485 dst: &mut EncodeBuf<'_>,
486 ) -> Option<Continuation> {
487 // At this point, the `is_end_headers` flag should always be set
488 debug_assert!(self.flags.is_end_headers());
489
490 let head = self.head();
491 let promised_id = self.promised_id;
492
493 self.header_block
494 .into_encoding(encoder)
495 .encode(&head, dst, |dst| {
496 dst.put_u32(promised_id.into());
497 })
498 }
499
500 fn head(&self) -> Head {
501 Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
502 }
503
504 /// Consume `self`, returning the parts of the frame
505 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
506 (self.header_block.pseudo, self.header_block.fields)
507 }
508}
509
510impl<T> From<PushPromise> for Frame<T> {
511 fn from(src: PushPromise) -> Self {
512 Frame::PushPromise(src)
513 }
514}
515
516impl fmt::Debug for PushPromise {
517 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
518 f&mut DebugStruct<'_, '_>.debug_struct("PushPromise")
519 .field("stream_id", &self.stream_id)
520 .field("promised_id", &self.promised_id)
521 .field(name:"flags", &self.flags)
522 // `fields` and `pseudo` purposefully not included
523 .finish()
524 }
525}
526
527// ===== impl Continuation =====
528
529impl Continuation {
530 fn head(&self) -> Head {
531 Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
532 }
533
534 pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
535 // Get the CONTINUATION frame head
536 let head: Head = self.head();
537
538 self.header_block.encode(&head, dst, |_| {})
539 }
540}
541
542// ===== impl Pseudo =====
543
544impl Pseudo {
545 pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
546 let parts = uri::Parts::from(uri);
547
548 let mut path = parts
549 .path_and_query
550 .map(|v| BytesStr::from(v.as_str()))
551 .unwrap_or(BytesStr::from_static(""));
552
553 match method {
554 Method::OPTIONS | Method::CONNECT => {}
555 _ if path.is_empty() => {
556 path = BytesStr::from_static("/");
557 }
558 _ => {}
559 }
560
561 let mut pseudo = Pseudo {
562 method: Some(method),
563 scheme: None,
564 authority: None,
565 path: Some(path).filter(|p| !p.is_empty()),
566 protocol,
567 status: None,
568 };
569
570 // If the URI includes a scheme component, add it to the pseudo headers
571 //
572 // TODO: Scheme must be set...
573 if let Some(scheme) = parts.scheme {
574 pseudo.set_scheme(scheme);
575 }
576
577 // If the URI includes an authority component, add it to the pseudo
578 // headers
579 if let Some(authority) = parts.authority {
580 pseudo.set_authority(BytesStr::from(authority.as_str()));
581 }
582
583 pseudo
584 }
585
586 pub fn response(status: StatusCode) -> Self {
587 Pseudo {
588 method: None,
589 scheme: None,
590 authority: None,
591 path: None,
592 protocol: None,
593 status: Some(status),
594 }
595 }
596
597 #[cfg(feature = "unstable")]
598 pub fn set_status(&mut self, value: StatusCode) {
599 self.status = Some(value);
600 }
601
602 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
603 let bytes_str = match scheme.as_str() {
604 "http" => BytesStr::from_static("http"),
605 "https" => BytesStr::from_static("https"),
606 s => BytesStr::from(s),
607 };
608 self.scheme = Some(bytes_str);
609 }
610
611 #[cfg(feature = "unstable")]
612 pub fn set_protocol(&mut self, protocol: Protocol) {
613 self.protocol = Some(protocol);
614 }
615
616 pub fn set_authority(&mut self, authority: BytesStr) {
617 self.authority = Some(authority);
618 }
619
620 /// Whether it has status 1xx
621 pub(crate) fn is_informational(&self) -> bool {
622 self.status
623 .map_or(false, |status| status.is_informational())
624 }
625}
626
627// ===== impl EncodingHeaderBlock =====
628
impl EncodingHeaderBlock {
    /// Writes one frame into `dst`: `head`, then whatever `f` writes, then as
    /// many HPACK bytes as fit.
    ///
    /// `f` runs after the frame head and before the header block fragment, so
    /// callers can emit fixed fields that precede the fragment. PUSH_PROMISE
    /// uses this for the promised stream ID.
    ///
    /// If the remaining HPACK bytes exceed `dst`'s remaining limit, the frame is
    /// filled up to the limit and the `END_HEADERS` bit is cleared in the frame
    /// head that was just written. The unwritten remainder is returned as a
    /// `Continuation`. Otherwise everything is written and `None` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the written payload length does not fit in the 3-byte frame
    /// length field.
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
    where
        F: FnOnce(&mut EncodeBuf<'_>),
    {
        // Offset of this frame's head inside the underlying buffer; used below
        // to patch the length and flags bytes in place.
        let head_pos = dst.get_ref().len();

        // At this point, we don't know how big the h2 frame will be.
        // So, we write the head with length 0, then write the body, and
        // finally write the length once we know the size.
        head.encode(0, dst);

        let payload_pos = dst.get_ref().len();

        // Frame-specific fields that precede the header block fragment.
        f(dst);

        // Now, encode the header payload. When it does not fit in what remains
        // of the limit, write as much as fits and keep the rest for a
        // CONTINUATION frame.
        let continuation = if self.hpack.len() > dst.remaining_mut() {
            dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));

            Some(Continuation {
                stream_id: head.stream_id(),
                header_block: self,
            })
        } else {
            dst.put_slice(&self.hpack);

            None
        };

        // Compute the header block length (includes anything written by `f`)
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;

        // Write the frame length. The length field is 3 bytes wide, so the top
        // five bytes of the big-endian u64 must be zero; the low three are
        // copied over the placeholder written by `head.encode(0, dst)`.
        let payload_len_be = payload_len.to_be_bytes();
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);

        if continuation.is_some() {
            // There will be continuation frames, so the `is_end_headers` flag
            // must be unset. Byte 4 of the frame head (after the 3 length bytes
            // and the type byte) holds the flags.
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);

            dst.get_mut()[head_pos + 4] -= END_HEADERS;
        }

        continuation
    }
}
678
679// ===== impl Iter =====
680
681impl Iterator for Iter {
682 type Item = hpack::Header<Option<HeaderName>>;
683
684 fn next(&mut self) -> Option<Self::Item> {
685 use crate::hpack::Header::*;
686
687 if let Some(ref mut pseudo) = self.pseudo {
688 if let Some(method) = pseudo.method.take() {
689 return Some(Method(method));
690 }
691
692 if let Some(scheme) = pseudo.scheme.take() {
693 return Some(Scheme(scheme));
694 }
695
696 if let Some(authority) = pseudo.authority.take() {
697 return Some(Authority(authority));
698 }
699
700 if let Some(path) = pseudo.path.take() {
701 return Some(Path(path));
702 }
703
704 if let Some(protocol) = pseudo.protocol.take() {
705 return Some(Protocol(protocol));
706 }
707
708 if let Some(status) = pseudo.status.take() {
709 return Some(Status(status));
710 }
711 }
712
713 self.pseudo = None;
714
715 self.fields
716 .next()
717 .map(|(name, value)| Field { name, value })
718 }
719}
720
721// ===== impl HeadersFlag =====
722
723impl HeadersFlag {
724 pub fn empty() -> HeadersFlag {
725 HeadersFlag(0)
726 }
727
728 pub fn load(bits: u8) -> HeadersFlag {
729 HeadersFlag(bits & ALL)
730 }
731
732 pub fn is_end_stream(&self) -> bool {
733 self.0 & END_STREAM == END_STREAM
734 }
735
736 pub fn set_end_stream(&mut self) {
737 self.0 |= END_STREAM;
738 }
739
740 pub fn is_end_headers(&self) -> bool {
741 self.0 & END_HEADERS == END_HEADERS
742 }
743
744 pub fn set_end_headers(&mut self) {
745 self.0 |= END_HEADERS;
746 }
747
748 pub fn is_padded(&self) -> bool {
749 self.0 & PADDED == PADDED
750 }
751
752 pub fn is_priority(&self) -> bool {
753 self.0 & PRIORITY == PRIORITY
754 }
755}
756
757impl Default for HeadersFlag {
758 /// Returns a `HeadersFlag` value with `END_HEADERS` set.
759 fn default() -> Self {
760 HeadersFlag(END_HEADERS)
761 }
762}
763
764impl From<HeadersFlag> for u8 {
765 fn from(src: HeadersFlag) -> u8 {
766 src.0
767 }
768}
769
770impl fmt::Debug for HeadersFlag {
771 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
772 util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0)
773 .flag_if(self.is_end_headers(), "END_HEADERS")
774 .flag_if(self.is_end_stream(), "END_STREAM")
775 .flag_if(self.is_padded(), "PADDED")
776 .flag_if(self.is_priority(), name:"PRIORITY")
777 .finish()
778 }
779}
780
781// ===== impl PushPromiseFlag =====
782
783impl PushPromiseFlag {
784 pub fn empty() -> PushPromiseFlag {
785 PushPromiseFlag(0)
786 }
787
788 pub fn load(bits: u8) -> PushPromiseFlag {
789 PushPromiseFlag(bits & ALL)
790 }
791
792 pub fn is_end_headers(&self) -> bool {
793 self.0 & END_HEADERS == END_HEADERS
794 }
795
796 pub fn set_end_headers(&mut self) {
797 self.0 |= END_HEADERS;
798 }
799
800 pub fn is_padded(&self) -> bool {
801 self.0 & PADDED == PADDED
802 }
803}
804
805impl Default for PushPromiseFlag {
806 /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
807 fn default() -> Self {
808 PushPromiseFlag(END_HEADERS)
809 }
810}
811
812impl From<PushPromiseFlag> for u8 {
813 fn from(src: PushPromiseFlag) -> u8 {
814 src.0
815 }
816}
817
818impl fmt::Debug for PushPromiseFlag {
819 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
820 util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0)
821 .flag_if(self.is_end_headers(), "END_HEADERS")
822 .flag_if(self.is_padded(), name:"PADDED")
823 .finish()
824 }
825}
826
827// ===== HeaderBlock =====
828
829impl HeaderBlock {
830 fn load(
831 &mut self,
832 src: &mut BytesMut,
833 max_header_list_size: usize,
834 decoder: &mut hpack::Decoder,
835 ) -> Result<(), Error> {
836 let mut reg = !self.fields.is_empty();
837 let mut malformed = false;
838 let mut headers_size = self.calculate_header_list_size();
839
840 macro_rules! set_pseudo {
841 ($field:ident, $val:expr) => {{
842 if reg {
843 tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
844 malformed = true;
845 } else if self.pseudo.$field.is_some() {
846 tracing::trace!("load_hpack; header malformed -- repeated pseudo");
847 malformed = true;
848 } else {
849 let __val = $val;
850 headers_size +=
851 decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
852 if headers_size < max_header_list_size {
853 self.pseudo.$field = Some(__val);
854 } else if !self.is_over_size {
855 tracing::trace!("load_hpack; header list size over max");
856 self.is_over_size = true;
857 }
858 }
859 }};
860 }
861
862 let mut cursor = Cursor::new(src);
863
864 // If the header frame is malformed, we still have to continue decoding
865 // the headers. A malformed header frame is a stream level error, but
866 // the hpack state is connection level. In order to maintain correct
867 // state for other streams, the hpack decoding process must complete.
868 let res = decoder.decode(&mut cursor, |header| {
869 use crate::hpack::Header::*;
870
871 match header {
872 Field { name, value } => {
873 // Connection level header fields are not supported and must
874 // result in a protocol error.
875
876 if name == header::CONNECTION
877 || name == header::TRANSFER_ENCODING
878 || name == header::UPGRADE
879 || name == "keep-alive"
880 || name == "proxy-connection"
881 {
882 tracing::trace!("load_hpack; connection level header");
883 malformed = true;
884 } else if name == header::TE && value != "trailers" {
885 tracing::trace!(
886 "load_hpack; TE header not set to trailers; val={:?}",
887 value
888 );
889 malformed = true;
890 } else {
891 reg = true;
892
893 headers_size += decoded_header_size(name.as_str().len(), value.len());
894 if headers_size < max_header_list_size {
895 self.fields.append(name, value);
896 } else if !self.is_over_size {
897 tracing::trace!("load_hpack; header list size over max");
898 self.is_over_size = true;
899 }
900 }
901 }
902 Authority(v) => set_pseudo!(authority, v),
903 Method(v) => set_pseudo!(method, v),
904 Scheme(v) => set_pseudo!(scheme, v),
905 Path(v) => set_pseudo!(path, v),
906 Protocol(v) => set_pseudo!(protocol, v),
907 Status(v) => set_pseudo!(status, v),
908 }
909 });
910
911 if let Err(e) = res {
912 tracing::trace!("hpack decoding error; err={:?}", e);
913 return Err(e.into());
914 }
915
916 if malformed {
917 tracing::trace!("malformed message");
918 return Err(Error::MalformedMessage);
919 }
920
921 Ok(())
922 }
923
924 fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
925 let mut hpack = BytesMut::new();
926 let headers = Iter {
927 pseudo: Some(self.pseudo),
928 fields: self.fields.into_iter(),
929 };
930
931 encoder.encode(headers, &mut hpack);
932
933 EncodingHeaderBlock {
934 hpack: hpack.freeze(),
935 }
936 }
937
938 /// Calculates the size of the currently decoded header list.
939 ///
940 /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
941 ///
942 /// > The value is based on the uncompressed size of header fields,
943 /// > including the length of the name and value in octets plus an
944 /// > overhead of 32 octets for each header field.
945 fn calculate_header_list_size(&self) -> usize {
946 macro_rules! pseudo_size {
947 ($name:ident) => {{
948 self.pseudo
949 .$name
950 .as_ref()
951 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
952 .unwrap_or(0)
953 }};
954 }
955
956 pseudo_size!(method)
957 + pseudo_size!(scheme)
958 + pseudo_size!(status)
959 + pseudo_size!(authority)
960 + pseudo_size!(path)
961 + self
962 .fields
963 .iter()
964 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
965 .sum::<usize>()
966 }
967}
968
/// Size charged for one header field against the maximum header list size:
/// name length plus value length plus a fixed 32-octet per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    name + value + PER_FIELD_OVERHEAD
}
972
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use http::HeaderValue;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes a HEADERS frame whose HPACK block exceeds the write limit and
    /// checks that it is split across a HEADERS frame and a CONTINUATION frame.
    /// The split point falls in the middle of the first header value.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        // Three values under the same header name.
        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Allow only the frame head plus 8 payload bytes, forcing a split.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame head: length 8, type 0x1, flags 0 (END_HEADERS was
        // cleared because a CONTINUATION follows), stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // Literal field with a new name (0x40), Huffman-coded 4-byte name.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman-coded 4-byte value; only its first byte fits in this frame.
        assert_eq!(0x80 | 4, dst[15]);

        // Keep the part of "world" that made it into the HEADERS frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // The rest fits in a single CONTINUATION, so nothing further is returned.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining 3 bytes of "world" lead the CONTINUATION payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame head: length 15, type 0x9, flags END_HEADERS (0x4).
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src`; panics if it is not valid Huffman-coded data.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}
1043