1 | use super::{util, StreamDependency, StreamId}; |
2 | use crate::ext::Protocol; |
3 | use crate::frame::{Error, Frame, Head, Kind}; |
4 | use crate::hpack::{self, BytesStr}; |
5 | |
6 | use http::header::{self, HeaderName, HeaderValue}; |
7 | use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; |
8 | |
9 | use bytes::{Buf, BufMut, Bytes, BytesMut}; |
10 | |
11 | use std::fmt; |
12 | use std::io::Cursor; |
13 | |
14 | type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; |
15 | |
/// Header frame
///
/// This could be either a request or a response (or a trailers block, see
/// `Headers::trailers`). Holds the header block together with the frame's
/// stream id, optional stream dependency and flags.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any. Only populated by `load`
    /// when the PRIORITY flag is set.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}
33 | |
/// The raw flags byte of a HEADERS frame. The bits inspected by its methods
/// are `END_STREAM`, `END_HEADERS`, `PADDED` and `PRIORITY`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36 | |
/// PUSH_PROMISE frame: announces, on `stream_id`, a request the sender
/// intends to answer on the reserved `promised_id` stream.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}
51 | |
/// The raw flags byte of a PUSH_PROMISE frame. The bits inspected by its
/// methods are `END_HEADERS` and `PADDED`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54 | |
/// The part of an encoded header block that did not fit into the previous
/// HEADERS / PUSH_PROMISE / CONTINUATION frame and still has to be written
/// in a CONTINUATION frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// HPACK bytes that have not been written yet.
    header_block: EncodingHeaderBlock,
}
62 | |
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo header fields. Requests use `:method`, `:scheme`,
/// `:authority`, `:path` and (for extended CONNECT) `:protocol`; responses
/// use `:status`. A `None` field is simply not sent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,
}
76 | |
/// Iterator handed to the HPACK encoder: yields every pseudo header first,
/// then the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; set to `None` once all of them have been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85 | |
/// A decoded (or not-yet-encoded) header block: pseudo headers plus regular
/// fields, with bookkeeping for the header-list size limit.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons
    /// (sum of `decoded_header_size` over `fields`; pseudo headers excluded).
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}
101 | |
/// A header block that has already been HPACK-encoded and is waiting to be
/// written into one or more frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written to a frame; advanced as frames are
    /// filled.
    hpack: Bytes,
}
106 | |
// Frame flag bits for HEADERS (RFC 7540 §6.2); PUSH_PROMISE reuses
// END_HEADERS and PADDED (RFC 7540 §6.6).

/// END_STREAM flag bit.
const END_STREAM: u8 = 0b0000_0001;
/// END_HEADERS flag bit: no CONTINUATION frames follow this one.
const END_HEADERS: u8 = 0b0000_0100;
/// PADDED flag bit: a pad-length byte and trailing padding are present.
const PADDED: u8 = 0b0000_1000;
/// PRIORITY flag bit: a 5-byte stream dependency section is present.
const PRIORITY: u8 = 0b0010_0000;
/// Every flag bit defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112 | |
113 | // ===== impl Headers ===== |
114 | |
115 | impl Headers { |
116 | /// Create a new HEADERS frame |
117 | pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self { |
118 | Headers { |
119 | stream_id, |
120 | stream_dep: None, |
121 | header_block: HeaderBlock { |
122 | field_size: calculate_headermap_size(&fields), |
123 | fields, |
124 | is_over_size: false, |
125 | pseudo, |
126 | }, |
127 | flags: HeadersFlag::default(), |
128 | } |
129 | } |
130 | |
131 | pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self { |
132 | let mut flags = HeadersFlag::default(); |
133 | flags.set_end_stream(); |
134 | |
135 | Headers { |
136 | stream_id, |
137 | stream_dep: None, |
138 | header_block: HeaderBlock { |
139 | field_size: calculate_headermap_size(&fields), |
140 | fields, |
141 | is_over_size: false, |
142 | pseudo: Pseudo::default(), |
143 | }, |
144 | flags, |
145 | } |
146 | } |
147 | |
148 | /// Loads the header frame but doesn't actually do HPACK decoding. |
149 | /// |
150 | /// HPACK decoding is done in the `load_hpack` step. |
151 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
152 | let flags = HeadersFlag(head.flag()); |
153 | let mut pad = 0; |
154 | |
155 | tracing::trace!("loading headers; flags= {:?}" , flags); |
156 | |
157 | if head.stream_id().is_zero() { |
158 | return Err(Error::InvalidStreamId); |
159 | } |
160 | |
161 | // Read the padding length |
162 | if flags.is_padded() { |
163 | if src.is_empty() { |
164 | return Err(Error::MalformedMessage); |
165 | } |
166 | pad = src[0] as usize; |
167 | |
168 | // Drop the padding |
169 | src.advance(1); |
170 | } |
171 | |
172 | // Read the stream dependency |
173 | let stream_dep = if flags.is_priority() { |
174 | if src.len() < 5 { |
175 | return Err(Error::MalformedMessage); |
176 | } |
177 | let stream_dep = StreamDependency::load(&src[..5])?; |
178 | |
179 | if stream_dep.dependency_id() == head.stream_id() { |
180 | return Err(Error::InvalidDependencyId); |
181 | } |
182 | |
183 | // Drop the next 5 bytes |
184 | src.advance(5); |
185 | |
186 | Some(stream_dep) |
187 | } else { |
188 | None |
189 | }; |
190 | |
191 | if pad > 0 { |
192 | if pad > src.len() { |
193 | return Err(Error::TooMuchPadding); |
194 | } |
195 | |
196 | let len = src.len() - pad; |
197 | src.truncate(len); |
198 | } |
199 | |
200 | let headers = Headers { |
201 | stream_id: head.stream_id(), |
202 | stream_dep, |
203 | header_block: HeaderBlock { |
204 | fields: HeaderMap::new(), |
205 | field_size: 0, |
206 | is_over_size: false, |
207 | pseudo: Pseudo::default(), |
208 | }, |
209 | flags, |
210 | }; |
211 | |
212 | Ok((headers, src)) |
213 | } |
214 | |
215 | pub fn load_hpack( |
216 | &mut self, |
217 | src: &mut BytesMut, |
218 | max_header_list_size: usize, |
219 | decoder: &mut hpack::Decoder, |
220 | ) -> Result<(), Error> { |
221 | self.header_block.load(src, max_header_list_size, decoder) |
222 | } |
223 | |
224 | pub fn stream_id(&self) -> StreamId { |
225 | self.stream_id |
226 | } |
227 | |
228 | pub fn is_end_headers(&self) -> bool { |
229 | self.flags.is_end_headers() |
230 | } |
231 | |
232 | pub fn set_end_headers(&mut self) { |
233 | self.flags.set_end_headers(); |
234 | } |
235 | |
236 | pub fn is_end_stream(&self) -> bool { |
237 | self.flags.is_end_stream() |
238 | } |
239 | |
240 | pub fn set_end_stream(&mut self) { |
241 | self.flags.set_end_stream() |
242 | } |
243 | |
244 | pub fn is_over_size(&self) -> bool { |
245 | self.header_block.is_over_size |
246 | } |
247 | |
248 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
249 | (self.header_block.pseudo, self.header_block.fields) |
250 | } |
251 | |
252 | #[cfg (feature = "unstable" )] |
253 | pub fn pseudo_mut(&mut self) -> &mut Pseudo { |
254 | &mut self.header_block.pseudo |
255 | } |
256 | |
257 | pub(crate) fn pseudo(&self) -> &Pseudo { |
258 | &self.header_block.pseudo |
259 | } |
260 | |
261 | /// Whether it has status 1xx |
262 | pub(crate) fn is_informational(&self) -> bool { |
263 | self.header_block.pseudo.is_informational() |
264 | } |
265 | |
266 | pub fn fields(&self) -> &HeaderMap { |
267 | &self.header_block.fields |
268 | } |
269 | |
270 | pub fn into_fields(self) -> HeaderMap { |
271 | self.header_block.fields |
272 | } |
273 | |
274 | pub fn encode( |
275 | self, |
276 | encoder: &mut hpack::Encoder, |
277 | dst: &mut EncodeBuf<'_>, |
278 | ) -> Option<Continuation> { |
279 | // At this point, the `is_end_headers` flag should always be set |
280 | debug_assert!(self.flags.is_end_headers()); |
281 | |
282 | // Get the HEADERS frame head |
283 | let head = self.head(); |
284 | |
285 | self.header_block |
286 | .into_encoding(encoder) |
287 | .encode(&head, dst, |_| {}) |
288 | } |
289 | |
290 | fn head(&self) -> Head { |
291 | Head::new(Kind::Headers, self.flags.into(), self.stream_id) |
292 | } |
293 | } |
294 | |
295 | impl<T> From<Headers> for Frame<T> { |
296 | fn from(src: Headers) -> Self { |
297 | Frame::Headers(src) |
298 | } |
299 | } |
300 | |
301 | impl fmt::Debug for Headers { |
302 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
303 | let mut builder: DebugStruct<'_, '_> = f.debug_struct(name:"Headers" ); |
304 | builder |
305 | .field("stream_id" , &self.stream_id) |
306 | .field(name:"flags" , &self.flags); |
307 | |
308 | if let Some(ref protocol: &Protocol) = self.header_block.pseudo.protocol { |
309 | builder.field(name:"protocol" , value:protocol); |
310 | } |
311 | |
312 | if let Some(ref dep: &StreamDependency) = self.stream_dep { |
313 | builder.field(name:"stream_dep" , value:dep); |
314 | } |
315 | |
316 | // `fields` and `pseudo` purposefully not included |
317 | builder.finish() |
318 | } |
319 | } |
320 | |
321 | // ===== util ===== |
322 | |
/// Error returned by [`parse_u64`] when its input is not a valid decimal
/// number: empty, too long, or containing a non-digit byte.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer such as a `content-length` value.
///
/// Accepts 1 to 19 ASCII digits (`1*DIGIT`). Signs, whitespace and
/// separators are rejected. Nineteen digits always fit in a `u64`, so the
/// accumulation below cannot overflow.
///
/// # Errors
///
/// Returns [`ParseU64Error`] if `src` is empty, longer than 19 bytes, or
/// contains a byte outside `b'0'..=b'9'`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // An empty value is not a number (Content-Length is `1*DIGIT`), so it
    // must not be read as 0. More than 19 digits risks overflow.
    if src.is_empty() || src.len() > 19 {
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            Ok(acc * 10 + u64::from(d - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
345 | |
346 | // ===== impl PushPromise ===== |
347 | |
/// Reasons a promised request is rejected by `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present but does not parse to exactly
    /// `0`; carries the result of parsing its value.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The method is not safe and cacheable (only `GET` and `HEAD` pass).
    NotSafeAndCacheable,
}
353 | |
354 | impl PushPromise { |
355 | pub fn new( |
356 | stream_id: StreamId, |
357 | promised_id: StreamId, |
358 | pseudo: Pseudo, |
359 | fields: HeaderMap, |
360 | ) -> Self { |
361 | PushPromise { |
362 | flags: PushPromiseFlag::default(), |
363 | header_block: HeaderBlock { |
364 | field_size: calculate_headermap_size(&fields), |
365 | fields, |
366 | is_over_size: false, |
367 | pseudo, |
368 | }, |
369 | promised_id, |
370 | stream_id, |
371 | } |
372 | } |
373 | |
374 | pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> { |
375 | use PushPromiseHeaderError::*; |
376 | // The spec has some requirements for promised request headers |
377 | // [https://httpwg.org/specs/rfc7540.html#PushRequests] |
378 | |
379 | // A promised request "that indicates the presence of a request body |
380 | // MUST reset the promised stream with a stream error" |
381 | if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) { |
382 | let parsed_length = parse_u64(content_length.as_bytes()); |
383 | if parsed_length != Ok(0) { |
384 | return Err(InvalidContentLength(parsed_length)); |
385 | } |
386 | } |
387 | // "The server MUST include a method in the :method pseudo-header field |
388 | // that is safe and cacheable" |
389 | if !Self::safe_and_cacheable(req.method()) { |
390 | return Err(NotSafeAndCacheable); |
391 | } |
392 | |
393 | Ok(()) |
394 | } |
395 | |
396 | fn safe_and_cacheable(method: &Method) -> bool { |
397 | // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods |
398 | // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods |
399 | method == Method::GET || method == Method::HEAD |
400 | } |
401 | |
402 | pub fn fields(&self) -> &HeaderMap { |
403 | &self.header_block.fields |
404 | } |
405 | |
406 | #[cfg (feature = "unstable" )] |
407 | pub fn into_fields(self) -> HeaderMap { |
408 | self.header_block.fields |
409 | } |
410 | |
411 | /// Loads the push promise frame but doesn't actually do HPACK decoding. |
412 | /// |
413 | /// HPACK decoding is done in the `load_hpack` step. |
414 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
415 | let flags = PushPromiseFlag(head.flag()); |
416 | let mut pad = 0; |
417 | |
418 | if head.stream_id().is_zero() { |
419 | return Err(Error::InvalidStreamId); |
420 | } |
421 | |
422 | // Read the padding length |
423 | if flags.is_padded() { |
424 | if src.is_empty() { |
425 | return Err(Error::MalformedMessage); |
426 | } |
427 | |
428 | // TODO: Ensure payload is sized correctly |
429 | pad = src[0] as usize; |
430 | |
431 | // Drop the padding |
432 | src.advance(1); |
433 | } |
434 | |
435 | if src.len() < 5 { |
436 | return Err(Error::MalformedMessage); |
437 | } |
438 | |
439 | let (promised_id, _) = StreamId::parse(&src[..4]); |
440 | // Drop promised_id bytes |
441 | src.advance(4); |
442 | |
443 | if pad > 0 { |
444 | if pad > src.len() { |
445 | return Err(Error::TooMuchPadding); |
446 | } |
447 | |
448 | let len = src.len() - pad; |
449 | src.truncate(len); |
450 | } |
451 | |
452 | let frame = PushPromise { |
453 | flags, |
454 | header_block: HeaderBlock { |
455 | fields: HeaderMap::new(), |
456 | field_size: 0, |
457 | is_over_size: false, |
458 | pseudo: Pseudo::default(), |
459 | }, |
460 | promised_id, |
461 | stream_id: head.stream_id(), |
462 | }; |
463 | Ok((frame, src)) |
464 | } |
465 | |
466 | pub fn load_hpack( |
467 | &mut self, |
468 | src: &mut BytesMut, |
469 | max_header_list_size: usize, |
470 | decoder: &mut hpack::Decoder, |
471 | ) -> Result<(), Error> { |
472 | self.header_block.load(src, max_header_list_size, decoder) |
473 | } |
474 | |
475 | pub fn stream_id(&self) -> StreamId { |
476 | self.stream_id |
477 | } |
478 | |
479 | pub fn promised_id(&self) -> StreamId { |
480 | self.promised_id |
481 | } |
482 | |
483 | pub fn is_end_headers(&self) -> bool { |
484 | self.flags.is_end_headers() |
485 | } |
486 | |
487 | pub fn set_end_headers(&mut self) { |
488 | self.flags.set_end_headers(); |
489 | } |
490 | |
491 | pub fn is_over_size(&self) -> bool { |
492 | self.header_block.is_over_size |
493 | } |
494 | |
495 | pub fn encode( |
496 | self, |
497 | encoder: &mut hpack::Encoder, |
498 | dst: &mut EncodeBuf<'_>, |
499 | ) -> Option<Continuation> { |
500 | // At this point, the `is_end_headers` flag should always be set |
501 | debug_assert!(self.flags.is_end_headers()); |
502 | |
503 | let head = self.head(); |
504 | let promised_id = self.promised_id; |
505 | |
506 | self.header_block |
507 | .into_encoding(encoder) |
508 | .encode(&head, dst, |dst| { |
509 | dst.put_u32(promised_id.into()); |
510 | }) |
511 | } |
512 | |
513 | fn head(&self) -> Head { |
514 | Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) |
515 | } |
516 | |
517 | /// Consume `self`, returning the parts of the frame |
518 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
519 | (self.header_block.pseudo, self.header_block.fields) |
520 | } |
521 | } |
522 | |
523 | impl<T> From<PushPromise> for Frame<T> { |
524 | fn from(src: PushPromise) -> Self { |
525 | Frame::PushPromise(src) |
526 | } |
527 | } |
528 | |
529 | impl fmt::Debug for PushPromise { |
530 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
531 | f&mut DebugStruct<'_, '_>.debug_struct("PushPromise" ) |
532 | .field("stream_id" , &self.stream_id) |
533 | .field("promised_id" , &self.promised_id) |
534 | .field(name:"flags" , &self.flags) |
535 | // `fields` and `pseudo` purposefully not included |
536 | .finish() |
537 | } |
538 | } |
539 | |
540 | // ===== impl Continuation ===== |
541 | |
542 | impl Continuation { |
543 | fn head(&self) -> Head { |
544 | Head::new(Kind::Continuation, END_HEADERS, self.stream_id) |
545 | } |
546 | |
547 | pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> { |
548 | // Get the CONTINUATION frame head |
549 | let head: Head = self.head(); |
550 | |
551 | self.header_block.encode(&head, dst, |_| {}) |
552 | } |
553 | } |
554 | |
555 | // ===== impl Pseudo ===== |
556 | |
557 | impl Pseudo { |
558 | pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { |
559 | let parts = uri::Parts::from(uri); |
560 | |
561 | let (scheme, path) = if method == Method::CONNECT && protocol.is_none() { |
562 | (None, None) |
563 | } else { |
564 | let path = parts |
565 | .path_and_query |
566 | .map(|v| BytesStr::from(v.as_str())) |
567 | .unwrap_or(BytesStr::from_static("" )); |
568 | |
569 | let path = if !path.is_empty() { |
570 | path |
571 | } else if method == Method::OPTIONS { |
572 | BytesStr::from_static("*" ) |
573 | } else { |
574 | BytesStr::from_static("/" ) |
575 | }; |
576 | |
577 | (parts.scheme, Some(path)) |
578 | }; |
579 | |
580 | let mut pseudo = Pseudo { |
581 | method: Some(method), |
582 | scheme: None, |
583 | authority: None, |
584 | path, |
585 | protocol, |
586 | status: None, |
587 | }; |
588 | |
589 | // If the URI includes a scheme component, add it to the pseudo headers |
590 | if let Some(scheme) = scheme { |
591 | pseudo.set_scheme(scheme); |
592 | } |
593 | |
594 | // If the URI includes an authority component, add it to the pseudo |
595 | // headers |
596 | if let Some(authority) = parts.authority { |
597 | pseudo.set_authority(BytesStr::from(authority.as_str())); |
598 | } |
599 | |
600 | pseudo |
601 | } |
602 | |
603 | pub fn response(status: StatusCode) -> Self { |
604 | Pseudo { |
605 | method: None, |
606 | scheme: None, |
607 | authority: None, |
608 | path: None, |
609 | protocol: None, |
610 | status: Some(status), |
611 | } |
612 | } |
613 | |
614 | #[cfg (feature = "unstable" )] |
615 | pub fn set_status(&mut self, value: StatusCode) { |
616 | self.status = Some(value); |
617 | } |
618 | |
619 | pub fn set_scheme(&mut self, scheme: uri::Scheme) { |
620 | let bytes_str = match scheme.as_str() { |
621 | "http" => BytesStr::from_static("http" ), |
622 | "https" => BytesStr::from_static("https" ), |
623 | s => BytesStr::from(s), |
624 | }; |
625 | self.scheme = Some(bytes_str); |
626 | } |
627 | |
628 | #[cfg (feature = "unstable" )] |
629 | pub fn set_protocol(&mut self, protocol: Protocol) { |
630 | self.protocol = Some(protocol); |
631 | } |
632 | |
633 | pub fn set_authority(&mut self, authority: BytesStr) { |
634 | self.authority = Some(authority); |
635 | } |
636 | |
637 | /// Whether it has status 1xx |
638 | pub(crate) fn is_informational(&self) -> bool { |
639 | self.status |
640 | .map_or(false, |status| status.is_informational()) |
641 | } |
642 | } |
643 | |
644 | // ===== impl EncodingHeaderBlock ===== |
645 | |
646 | impl EncodingHeaderBlock { |
647 | fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> |
648 | where |
649 | F: FnOnce(&mut EncodeBuf<'_>), |
650 | { |
651 | let head_pos = dst.get_ref().len(); |
652 | |
653 | // At this point, we don't know how big the h2 frame will be. |
654 | // So, we write the head with length 0, then write the body, and |
655 | // finally write the length once we know the size. |
656 | head.encode(0, dst); |
657 | |
658 | let payload_pos = dst.get_ref().len(); |
659 | |
660 | f(dst); |
661 | |
662 | // Now, encode the header payload |
663 | let continuation = if self.hpack.len() > dst.remaining_mut() { |
664 | dst.put((&mut self.hpack).take(dst.remaining_mut())); |
665 | |
666 | Some(Continuation { |
667 | stream_id: head.stream_id(), |
668 | header_block: self, |
669 | }) |
670 | } else { |
671 | dst.put_slice(&self.hpack); |
672 | |
673 | None |
674 | }; |
675 | |
676 | // Compute the header block length |
677 | let payload_len = (dst.get_ref().len() - payload_pos) as u64; |
678 | |
679 | // Write the frame length |
680 | let payload_len_be = payload_len.to_be_bytes(); |
681 | assert!(payload_len_be[0..5].iter().all(|b| *b == 0)); |
682 | (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]); |
683 | |
684 | if continuation.is_some() { |
685 | // There will be continuation frames, so the `is_end_headers` flag |
686 | // must be unset |
687 | debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS); |
688 | |
689 | dst.get_mut()[head_pos + 4] -= END_HEADERS; |
690 | } |
691 | |
692 | continuation |
693 | } |
694 | } |
695 | |
696 | // ===== impl Iter ===== |
697 | |
698 | impl Iterator for Iter { |
699 | type Item = hpack::Header<Option<HeaderName>>; |
700 | |
701 | fn next(&mut self) -> Option<Self::Item> { |
702 | use crate::hpack::Header::*; |
703 | |
704 | if let Some(ref mut pseudo) = self.pseudo { |
705 | if let Some(method) = pseudo.method.take() { |
706 | return Some(Method(method)); |
707 | } |
708 | |
709 | if let Some(scheme) = pseudo.scheme.take() { |
710 | return Some(Scheme(scheme)); |
711 | } |
712 | |
713 | if let Some(authority) = pseudo.authority.take() { |
714 | return Some(Authority(authority)); |
715 | } |
716 | |
717 | if let Some(path) = pseudo.path.take() { |
718 | return Some(Path(path)); |
719 | } |
720 | |
721 | if let Some(protocol) = pseudo.protocol.take() { |
722 | return Some(Protocol(protocol)); |
723 | } |
724 | |
725 | if let Some(status) = pseudo.status.take() { |
726 | return Some(Status(status)); |
727 | } |
728 | } |
729 | |
730 | self.pseudo = None; |
731 | |
732 | self.fields |
733 | .next() |
734 | .map(|(name, value)| Field { name, value }) |
735 | } |
736 | } |
737 | |
738 | // ===== impl HeadersFlag ===== |
739 | |
740 | impl HeadersFlag { |
741 | pub fn empty() -> HeadersFlag { |
742 | HeadersFlag(0) |
743 | } |
744 | |
745 | pub fn load(bits: u8) -> HeadersFlag { |
746 | HeadersFlag(bits & ALL) |
747 | } |
748 | |
749 | pub fn is_end_stream(&self) -> bool { |
750 | self.0 & END_STREAM == END_STREAM |
751 | } |
752 | |
753 | pub fn set_end_stream(&mut self) { |
754 | self.0 |= END_STREAM; |
755 | } |
756 | |
757 | pub fn is_end_headers(&self) -> bool { |
758 | self.0 & END_HEADERS == END_HEADERS |
759 | } |
760 | |
761 | pub fn set_end_headers(&mut self) { |
762 | self.0 |= END_HEADERS; |
763 | } |
764 | |
765 | pub fn is_padded(&self) -> bool { |
766 | self.0 & PADDED == PADDED |
767 | } |
768 | |
769 | pub fn is_priority(&self) -> bool { |
770 | self.0 & PRIORITY == PRIORITY |
771 | } |
772 | } |
773 | |
774 | impl Default for HeadersFlag { |
775 | /// Returns a `HeadersFlag` value with `END_HEADERS` set. |
776 | fn default() -> Self { |
777 | HeadersFlag(END_HEADERS) |
778 | } |
779 | } |
780 | |
781 | impl From<HeadersFlag> for u8 { |
782 | fn from(src: HeadersFlag) -> u8 { |
783 | src.0 |
784 | } |
785 | } |
786 | |
787 | impl fmt::Debug for HeadersFlag { |
788 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
789 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
790 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
791 | .flag_if(self.is_end_stream(), "END_STREAM" ) |
792 | .flag_if(self.is_padded(), "PADDED" ) |
793 | .flag_if(self.is_priority(), name:"PRIORITY" ) |
794 | .finish() |
795 | } |
796 | } |
797 | |
798 | // ===== impl PushPromiseFlag ===== |
799 | |
800 | impl PushPromiseFlag { |
801 | pub fn empty() -> PushPromiseFlag { |
802 | PushPromiseFlag(0) |
803 | } |
804 | |
805 | pub fn load(bits: u8) -> PushPromiseFlag { |
806 | PushPromiseFlag(bits & ALL) |
807 | } |
808 | |
809 | pub fn is_end_headers(&self) -> bool { |
810 | self.0 & END_HEADERS == END_HEADERS |
811 | } |
812 | |
813 | pub fn set_end_headers(&mut self) { |
814 | self.0 |= END_HEADERS; |
815 | } |
816 | |
817 | pub fn is_padded(&self) -> bool { |
818 | self.0 & PADDED == PADDED |
819 | } |
820 | } |
821 | |
822 | impl Default for PushPromiseFlag { |
823 | /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. |
824 | fn default() -> Self { |
825 | PushPromiseFlag(END_HEADERS) |
826 | } |
827 | } |
828 | |
829 | impl From<PushPromiseFlag> for u8 { |
830 | fn from(src: PushPromiseFlag) -> u8 { |
831 | src.0 |
832 | } |
833 | } |
834 | |
835 | impl fmt::Debug for PushPromiseFlag { |
836 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
837 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
838 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
839 | .flag_if(self.is_padded(), name:"PADDED" ) |
840 | .finish() |
841 | } |
842 | } |
843 | |
844 | // ===== HeaderBlock ===== |
845 | |
846 | impl HeaderBlock { |
847 | fn load( |
848 | &mut self, |
849 | src: &mut BytesMut, |
850 | max_header_list_size: usize, |
851 | decoder: &mut hpack::Decoder, |
852 | ) -> Result<(), Error> { |
853 | let mut reg = !self.fields.is_empty(); |
854 | let mut malformed = false; |
855 | let mut headers_size = self.calculate_header_list_size(); |
856 | |
857 | macro_rules! set_pseudo { |
858 | ($field:ident, $val:expr) => {{ |
859 | if reg { |
860 | tracing::trace!("load_hpack; header malformed -- pseudo not at head of block" ); |
861 | malformed = true; |
862 | } else if self.pseudo.$field.is_some() { |
863 | tracing::trace!("load_hpack; header malformed -- repeated pseudo" ); |
864 | malformed = true; |
865 | } else { |
866 | let __val = $val; |
867 | headers_size += |
868 | decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); |
869 | if headers_size < max_header_list_size { |
870 | self.pseudo.$field = Some(__val); |
871 | } else if !self.is_over_size { |
872 | tracing::trace!("load_hpack; header list size over max" ); |
873 | self.is_over_size = true; |
874 | } |
875 | } |
876 | }}; |
877 | } |
878 | |
879 | let mut cursor = Cursor::new(src); |
880 | |
881 | // If the header frame is malformed, we still have to continue decoding |
882 | // the headers. A malformed header frame is a stream level error, but |
883 | // the hpack state is connection level. In order to maintain correct |
884 | // state for other streams, the hpack decoding process must complete. |
885 | let res = decoder.decode(&mut cursor, |header| { |
886 | use crate::hpack::Header::*; |
887 | |
888 | match header { |
889 | Field { name, value } => { |
890 | // Connection level header fields are not supported and must |
891 | // result in a protocol error. |
892 | |
893 | if name == header::CONNECTION |
894 | || name == header::TRANSFER_ENCODING |
895 | || name == header::UPGRADE |
896 | || name == "keep-alive" |
897 | || name == "proxy-connection" |
898 | { |
899 | tracing::trace!("load_hpack; connection level header" ); |
900 | malformed = true; |
901 | } else if name == header::TE && value != "trailers" { |
902 | tracing::trace!( |
903 | "load_hpack; TE header not set to trailers; val= {:?}" , |
904 | value |
905 | ); |
906 | malformed = true; |
907 | } else { |
908 | reg = true; |
909 | |
910 | headers_size += decoded_header_size(name.as_str().len(), value.len()); |
911 | if headers_size < max_header_list_size { |
912 | self.field_size += |
913 | decoded_header_size(name.as_str().len(), value.len()); |
914 | self.fields.append(name, value); |
915 | } else if !self.is_over_size { |
916 | tracing::trace!("load_hpack; header list size over max" ); |
917 | self.is_over_size = true; |
918 | } |
919 | } |
920 | } |
921 | Authority(v) => set_pseudo!(authority, v), |
922 | Method(v) => set_pseudo!(method, v), |
923 | Scheme(v) => set_pseudo!(scheme, v), |
924 | Path(v) => set_pseudo!(path, v), |
925 | Protocol(v) => set_pseudo!(protocol, v), |
926 | Status(v) => set_pseudo!(status, v), |
927 | } |
928 | }); |
929 | |
930 | if let Err(e) = res { |
931 | tracing::trace!("hpack decoding error; err= {:?}" , e); |
932 | return Err(e.into()); |
933 | } |
934 | |
935 | if malformed { |
936 | tracing::trace!("malformed message" ); |
937 | return Err(Error::MalformedMessage); |
938 | } |
939 | |
940 | Ok(()) |
941 | } |
942 | |
943 | fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { |
944 | let mut hpack = BytesMut::new(); |
945 | let headers = Iter { |
946 | pseudo: Some(self.pseudo), |
947 | fields: self.fields.into_iter(), |
948 | }; |
949 | |
950 | encoder.encode(headers, &mut hpack); |
951 | |
952 | EncodingHeaderBlock { |
953 | hpack: hpack.freeze(), |
954 | } |
955 | } |
956 | |
957 | /// Calculates the size of the currently decoded header list. |
958 | /// |
959 | /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE |
960 | /// |
961 | /// > The value is based on the uncompressed size of header fields, |
962 | /// > including the length of the name and value in octets plus an |
963 | /// > overhead of 32 octets for each header field. |
964 | fn calculate_header_list_size(&self) -> usize { |
965 | macro_rules! pseudo_size { |
966 | ($name:ident) => {{ |
967 | self.pseudo |
968 | .$name |
969 | .as_ref() |
970 | .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) |
971 | .unwrap_or(0) |
972 | }}; |
973 | } |
974 | |
975 | pseudo_size!(method) |
976 | + pseudo_size!(scheme) |
977 | + pseudo_size!(status) |
978 | + pseudo_size!(authority) |
979 | + pseudo_size!(path) |
980 | + self.field_size |
981 | } |
982 | } |
983 | |
984 | fn calculate_headermap_size(map: &HeaderMap) -> usize { |
985 | mapimpl Iterator .iter() |
986 | .map(|(name: &HeaderName, value: &HeaderValue)| decoded_header_size(name.as_str().len(), value.len())) |
987 | .sum::<usize>() |
988 | } |
989 | |
/// Size of one header field as counted against the max header list size:
/// name length plus value length (octets) plus a fixed 32-octet overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    PER_FIELD_OVERHEAD + name + value
}
993 | |
994 | #[cfg (test)] |
995 | mod test { |
996 | use super::*; |
997 | use crate::frame; |
998 | use crate::hpack::{huffman, Encoder}; |
999 | |
1000 | #[test ] |
1001 | fn test_nameless_header_at_resume() { |
1002 | let mut encoder = Encoder::default(); |
1003 | let mut dst = BytesMut::new(); |
1004 | |
1005 | let headers = Headers::new( |
1006 | StreamId::ZERO, |
1007 | Default::default(), |
1008 | HeaderMap::from_iter(vec![ |
1009 | ( |
1010 | HeaderName::from_static("hello" ), |
1011 | HeaderValue::from_static("world" ), |
1012 | ), |
1013 | ( |
1014 | HeaderName::from_static("hello" ), |
1015 | HeaderValue::from_static("zomg" ), |
1016 | ), |
1017 | ( |
1018 | HeaderName::from_static("hello" ), |
1019 | HeaderValue::from_static("sup" ), |
1020 | ), |
1021 | ]), |
1022 | ); |
1023 | |
1024 | let continuation = headers |
1025 | .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8)) |
1026 | .unwrap(); |
1027 | |
1028 | assert_eq!(17, dst.len()); |
1029 | assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]); |
1030 | assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]); |
1031 | assert_eq!("hello" , huff_decode(&dst[11..15])); |
1032 | assert_eq!(0x80 | 4, dst[15]); |
1033 | |
1034 | let mut world = dst[16..17].to_owned(); |
1035 | |
1036 | dst.clear(); |
1037 | |
1038 | assert!(continuation |
1039 | .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16)) |
1040 | .is_none()); |
1041 | |
1042 | world.extend_from_slice(&dst[9..12]); |
1043 | assert_eq!("world" , huff_decode(&world)); |
1044 | |
1045 | assert_eq!(24, dst.len()); |
1046 | assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]); |
1047 | |
1048 | // // Next is not indexed |
1049 | assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]); |
1050 | assert_eq!("zomg" , huff_decode(&dst[15..18])); |
1051 | assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]); |
1052 | assert_eq!("sup" , huff_decode(&dst[21..])); |
1053 | } |
1054 | |
1055 | fn huff_decode(src: &[u8]) -> BytesMut { |
1056 | let mut buf = BytesMut::new(); |
1057 | huffman::decode(src, &mut buf).unwrap() |
1058 | } |
1059 | |
/// A plain `CONNECT` request (no `:protocol`) must carry only `:method` and
/// `:authority`; the `:scheme` and `:path` pseudo-header fields are omitted.
///
/// See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5
#[test]
fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
    // Each case: (target URI, expected `:authority`).
    let cases = [
        ("https://example.com:8443", "example.com:8443"),
        ("https://example.com/test", "example.com"),
        ("example.com:8443", "example.com:8443"),
    ];

    for (uri, authority) in cases {
        let actual = Pseudo::request(Method::CONNECT, Uri::from_static(uri), None);
        let expected = Pseudo {
            method: Some(Method::CONNECT),
            authority: Some(BytesStr::from_static(authority)),
            ..Pseudo::default()
        };
        assert_eq!(actual, expected);
    }
}
1100 | |
/// An extended `CONNECT` request (one carrying `:protocol`) must also include
/// the `:scheme` and `:path` of the target URI.
///
/// See: https://datatracker.ietf.org/doc/html/rfc8441#section-4
#[test]
fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
    const PROTOCOL: &str = "the-bread-protocol";

    // Each case: (target URI, expected `:scheme`, expected `:authority`, expected `:path`).
    let cases = [
        ("https://example.com:8443", "https", "example.com:8443", "/"),
        ("https://example.com:8443/test", "https", "example.com:8443", "/test"),
        ("http://example.com/a/b/c", "http", "example.com", "/a/b/c"),
    ];

    for (uri, scheme, authority, path) in cases {
        let actual = Pseudo::request(
            Method::CONNECT,
            Uri::from_static(uri),
            Some(Protocol::from_static(PROTOCOL)),
        );
        let expected = Pseudo {
            method: Some(Method::CONNECT),
            authority: Some(BytesStr::from_static(authority)),
            scheme: Some(BytesStr::from_static(scheme)),
            path: Some(BytesStr::from_static(path)),
            protocol: Some(Protocol::from_static(PROTOCOL)),
            ..Pseudo::default()
        };
        assert_eq!(actual, expected);
    }
}
1156 | |
/// An `OPTIONS` request whose target has no path component must use `*` as
/// its `:path` pseudo-header (see Section 7.1 of [HTTP]).
///
/// See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
#[test]
fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
    let target = Uri::from_static("example.com:8080");
    let pseudo = Pseudo::request(Method::OPTIONS, target, None);

    let want = Pseudo {
        method: Some(Method::OPTIONS),
        authority: Some(BytesStr::from_static("example.com:8080")),
        path: Some(BytesStr::from_static("*")),
        ..Pseudo::default()
    };
    assert_eq!(pseudo, want);
}
1172 | |
/// Every method other than `OPTIONS` and `CONNECT` carries `:scheme` and
/// `:path`; a target URI without a path produces `/` as the `:path`.
#[test]
fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
    // Each target: (URI, expected `:scheme`, expected `:authority`, expected `:path`).
    let targets = [
        ("http://example.com:8080", "http", "example.com:8080", "/"),
        ("https://example.com/a/b/c", "https", "example.com", "/a/b/c"),
    ];

    let methods = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::HEAD,
        Method::PATCH,
        Method::TRACE,
    ];

    for method in methods {
        for &(uri, scheme, authority, path) in &targets {
            let actual = Pseudo::request(method.clone(), Uri::from_static(uri), None);
            let expected = Pseudo {
                method: Some(method.clone()),
                authority: Some(BytesStr::from_static(authority)),
                scheme: Some(BytesStr::from_static(scheme)),
                path: Some(BytesStr::from_static(path)),
                ..Pseudo::default()
            };
            assert_eq!(actual, expected);
        }
    }
}
1216 | } |
1217 | |