1 | use super::{util, StreamDependency, StreamId}; |
2 | use crate::ext::Protocol; |
3 | use crate::frame::{Error, Frame, Head, Kind}; |
4 | use crate::hpack::{self, BytesStr}; |
5 | |
6 | use http::header::{self, HeaderName, HeaderValue}; |
7 | use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; |
8 | |
9 | use bytes::{BufMut, Bytes, BytesMut}; |
10 | |
11 | use std::fmt; |
12 | use std::io::Cursor; |
13 | |
/// Destination buffer type taken by the frame encoders in this module: a
/// mutable `BytesMut` wrapped in `bytes::buf::Limit`.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15 | |
16 | /// Header frame |
17 | /// |
18 | /// This could be either a request or a response. |
19 | #[derive (Eq, PartialEq)] |
20 | pub struct Headers { |
21 | /// The ID of the stream with which this frame is associated. |
22 | stream_id: StreamId, |
23 | |
24 | /// The stream dependency information, if any. |
25 | stream_dep: Option<StreamDependency>, |
26 | |
27 | /// The header block fragment |
28 | header_block: HeaderBlock, |
29 | |
30 | /// The associated flags |
31 | flags: HeadersFlag, |
32 | } |
33 | |
/// Flags byte of a HEADERS frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36 | |
37 | #[derive (Eq, PartialEq)] |
38 | pub struct PushPromise { |
39 | /// The ID of the stream with which this frame is associated. |
40 | stream_id: StreamId, |
41 | |
42 | /// The ID of the stream being reserved by this PushPromise. |
43 | promised_id: StreamId, |
44 | |
45 | /// The header block fragment |
46 | header_block: HeaderBlock, |
47 | |
48 | /// The associated flags |
49 | flags: PushPromiseFlag, |
50 | } |
51 | |
/// Flags byte of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54 | |
55 | #[derive (Debug)] |
56 | pub struct Continuation { |
57 | /// Stream ID of continuation frame |
58 | stream_id: StreamId, |
59 | |
60 | header_block: EncodingHeaderBlock, |
61 | } |
62 | |
63 | // TODO: These fields shouldn't be `pub` |
64 | #[derive (Debug, Default, Eq, PartialEq)] |
65 | pub struct Pseudo { |
66 | // Request |
67 | pub method: Option<Method>, |
68 | pub scheme: Option<BytesStr>, |
69 | pub authority: Option<BytesStr>, |
70 | pub path: Option<BytesStr>, |
71 | pub protocol: Option<Protocol>, |
72 | |
73 | // Response |
74 | pub status: Option<StatusCode>, |
75 | } |
76 | |
77 | #[derive (Debug)] |
78 | pub struct Iter { |
79 | /// Pseudo headers |
80 | pseudo: Option<Pseudo>, |
81 | |
82 | /// Header fields |
83 | fields: header::IntoIter<HeaderValue>, |
84 | } |
85 | |
86 | #[derive (Debug, PartialEq, Eq)] |
87 | struct HeaderBlock { |
88 | /// The decoded header fields |
89 | fields: HeaderMap, |
90 | |
91 | /// Precomputed size of all of our header fields, for perf reasons |
92 | field_size: usize, |
93 | |
94 | /// Set to true if decoding went over the max header list size. |
95 | is_over_size: bool, |
96 | |
97 | /// Pseudo headers, these are broken out as they must be sent as part of the |
98 | /// headers frame. |
99 | pseudo: Pseudo, |
100 | } |
101 | |
102 | #[derive (Debug)] |
103 | struct EncodingHeaderBlock { |
104 | hpack: Bytes, |
105 | } |
106 | |
// Bit masks for the flags byte of the frames handled in this module.
const END_STREAM: u8 = 1 << 0;
const END_HEADERS: u8 = 1 << 2;
const PADDED: u8 = 1 << 3;
const PRIORITY: u8 = 1 << 5;
// Union of every mask above.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112 | |
113 | // ===== impl Headers ===== |
114 | |
115 | impl Headers { |
116 | /// Create a new HEADERS frame |
117 | pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self { |
118 | Headers { |
119 | stream_id, |
120 | stream_dep: None, |
121 | header_block: HeaderBlock { |
122 | field_size: calculate_headermap_size(&fields), |
123 | fields, |
124 | is_over_size: false, |
125 | pseudo, |
126 | }, |
127 | flags: HeadersFlag::default(), |
128 | } |
129 | } |
130 | |
131 | pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self { |
132 | let mut flags = HeadersFlag::default(); |
133 | flags.set_end_stream(); |
134 | |
135 | Headers { |
136 | stream_id, |
137 | stream_dep: None, |
138 | header_block: HeaderBlock { |
139 | field_size: calculate_headermap_size(&fields), |
140 | fields, |
141 | is_over_size: false, |
142 | pseudo: Pseudo::default(), |
143 | }, |
144 | flags, |
145 | } |
146 | } |
147 | |
148 | /// Loads the header frame but doesn't actually do HPACK decoding. |
149 | /// |
150 | /// HPACK decoding is done in the `load_hpack` step. |
151 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
152 | let flags = HeadersFlag(head.flag()); |
153 | let mut pad = 0; |
154 | |
155 | tracing::trace!("loading headers; flags= {:?}" , flags); |
156 | |
157 | if head.stream_id().is_zero() { |
158 | return Err(Error::InvalidStreamId); |
159 | } |
160 | |
161 | // Read the padding length |
162 | if flags.is_padded() { |
163 | if src.is_empty() { |
164 | return Err(Error::MalformedMessage); |
165 | } |
166 | pad = src[0] as usize; |
167 | |
168 | // Drop the padding |
169 | let _ = src.split_to(1); |
170 | } |
171 | |
172 | // Read the stream dependency |
173 | let stream_dep = if flags.is_priority() { |
174 | if src.len() < 5 { |
175 | return Err(Error::MalformedMessage); |
176 | } |
177 | let stream_dep = StreamDependency::load(&src[..5])?; |
178 | |
179 | if stream_dep.dependency_id() == head.stream_id() { |
180 | return Err(Error::InvalidDependencyId); |
181 | } |
182 | |
183 | // Drop the next 5 bytes |
184 | let _ = src.split_to(5); |
185 | |
186 | Some(stream_dep) |
187 | } else { |
188 | None |
189 | }; |
190 | |
191 | if pad > 0 { |
192 | if pad > src.len() { |
193 | return Err(Error::TooMuchPadding); |
194 | } |
195 | |
196 | let len = src.len() - pad; |
197 | src.truncate(len); |
198 | } |
199 | |
200 | let headers = Headers { |
201 | stream_id: head.stream_id(), |
202 | stream_dep, |
203 | header_block: HeaderBlock { |
204 | fields: HeaderMap::new(), |
205 | field_size: 0, |
206 | is_over_size: false, |
207 | pseudo: Pseudo::default(), |
208 | }, |
209 | flags, |
210 | }; |
211 | |
212 | Ok((headers, src)) |
213 | } |
214 | |
215 | pub fn load_hpack( |
216 | &mut self, |
217 | src: &mut BytesMut, |
218 | max_header_list_size: usize, |
219 | decoder: &mut hpack::Decoder, |
220 | ) -> Result<(), Error> { |
221 | self.header_block.load(src, max_header_list_size, decoder) |
222 | } |
223 | |
224 | pub fn stream_id(&self) -> StreamId { |
225 | self.stream_id |
226 | } |
227 | |
228 | pub fn is_end_headers(&self) -> bool { |
229 | self.flags.is_end_headers() |
230 | } |
231 | |
232 | pub fn set_end_headers(&mut self) { |
233 | self.flags.set_end_headers(); |
234 | } |
235 | |
236 | pub fn is_end_stream(&self) -> bool { |
237 | self.flags.is_end_stream() |
238 | } |
239 | |
240 | pub fn set_end_stream(&mut self) { |
241 | self.flags.set_end_stream() |
242 | } |
243 | |
244 | pub fn is_over_size(&self) -> bool { |
245 | self.header_block.is_over_size |
246 | } |
247 | |
248 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
249 | (self.header_block.pseudo, self.header_block.fields) |
250 | } |
251 | |
252 | #[cfg (feature = "unstable" )] |
253 | pub fn pseudo_mut(&mut self) -> &mut Pseudo { |
254 | &mut self.header_block.pseudo |
255 | } |
256 | |
257 | /// Whether it has status 1xx |
258 | pub(crate) fn is_informational(&self) -> bool { |
259 | self.header_block.pseudo.is_informational() |
260 | } |
261 | |
262 | pub fn fields(&self) -> &HeaderMap { |
263 | &self.header_block.fields |
264 | } |
265 | |
266 | pub fn into_fields(self) -> HeaderMap { |
267 | self.header_block.fields |
268 | } |
269 | |
270 | pub fn encode( |
271 | self, |
272 | encoder: &mut hpack::Encoder, |
273 | dst: &mut EncodeBuf<'_>, |
274 | ) -> Option<Continuation> { |
275 | // At this point, the `is_end_headers` flag should always be set |
276 | debug_assert!(self.flags.is_end_headers()); |
277 | |
278 | // Get the HEADERS frame head |
279 | let head = self.head(); |
280 | |
281 | self.header_block |
282 | .into_encoding(encoder) |
283 | .encode(&head, dst, |_| {}) |
284 | } |
285 | |
286 | fn head(&self) -> Head { |
287 | Head::new(Kind::Headers, self.flags.into(), self.stream_id) |
288 | } |
289 | } |
290 | |
291 | impl<T> From<Headers> for Frame<T> { |
292 | fn from(src: Headers) -> Self { |
293 | Frame::Headers(src) |
294 | } |
295 | } |
296 | |
297 | impl fmt::Debug for Headers { |
298 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
299 | let mut builder: DebugStruct<'_, '_> = f.debug_struct(name:"Headers" ); |
300 | builder |
301 | .field("stream_id" , &self.stream_id) |
302 | .field(name:"flags" , &self.flags); |
303 | |
304 | if let Some(ref protocol: &Protocol) = self.header_block.pseudo.protocol { |
305 | builder.field(name:"protocol" , value:protocol); |
306 | } |
307 | |
308 | if let Some(ref dep: &StreamDependency) = self.stream_dep { |
309 | builder.field(name:"stream_dep" , value:dep); |
310 | } |
311 | |
312 | // `fields` and `pseudo` purposefully not included |
313 | builder.finish() |
314 | } |
315 | } |
316 | |
317 | // ===== util ===== |
318 | |
/// Error returned by [`parse_u64`] when the input is not a valid decimal
/// number.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses `src` as an unsigned decimal number made of ASCII digits only.
///
/// # Errors
///
/// Returns `ParseU64Error` when `src`
///
/// * is empty (a decimal number needs at least one digit),
/// * is longer than 19 bytes (longer inputs are at danger of overflowing a
///   `u64`; any 19-digit value still fits), or
/// * contains a byte that is not an ASCII digit (signs and whitespace
///   included).
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    if src.is_empty() || src.len() > 19 {
        return Err(ParseU64Error);
    }

    // The length check above guarantees that the arithmetic cannot overflow.
    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            Ok(acc * 10 + u64::from(d - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
341 | |
342 | // ===== impl PushPromise ===== |
343 | |
344 | #[derive (Debug)] |
345 | pub enum PushPromiseHeaderError { |
346 | InvalidContentLength(Result<u64, ParseU64Error>), |
347 | NotSafeAndCacheable, |
348 | } |
349 | |
350 | impl PushPromise { |
351 | pub fn new( |
352 | stream_id: StreamId, |
353 | promised_id: StreamId, |
354 | pseudo: Pseudo, |
355 | fields: HeaderMap, |
356 | ) -> Self { |
357 | PushPromise { |
358 | flags: PushPromiseFlag::default(), |
359 | header_block: HeaderBlock { |
360 | field_size: calculate_headermap_size(&fields), |
361 | fields, |
362 | is_over_size: false, |
363 | pseudo, |
364 | }, |
365 | promised_id, |
366 | stream_id, |
367 | } |
368 | } |
369 | |
370 | pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> { |
371 | use PushPromiseHeaderError::*; |
372 | // The spec has some requirements for promised request headers |
373 | // [https://httpwg.org/specs/rfc7540.html#PushRequests] |
374 | |
375 | // A promised request "that indicates the presence of a request body |
376 | // MUST reset the promised stream with a stream error" |
377 | if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) { |
378 | let parsed_length = parse_u64(content_length.as_bytes()); |
379 | if parsed_length != Ok(0) { |
380 | return Err(InvalidContentLength(parsed_length)); |
381 | } |
382 | } |
383 | // "The server MUST include a method in the :method pseudo-header field |
384 | // that is safe and cacheable" |
385 | if !Self::safe_and_cacheable(req.method()) { |
386 | return Err(NotSafeAndCacheable); |
387 | } |
388 | |
389 | Ok(()) |
390 | } |
391 | |
392 | fn safe_and_cacheable(method: &Method) -> bool { |
393 | // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods |
394 | // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods |
395 | method == Method::GET || method == Method::HEAD |
396 | } |
397 | |
398 | pub fn fields(&self) -> &HeaderMap { |
399 | &self.header_block.fields |
400 | } |
401 | |
402 | #[cfg (feature = "unstable" )] |
403 | pub fn into_fields(self) -> HeaderMap { |
404 | self.header_block.fields |
405 | } |
406 | |
407 | /// Loads the push promise frame but doesn't actually do HPACK decoding. |
408 | /// |
409 | /// HPACK decoding is done in the `load_hpack` step. |
410 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
411 | let flags = PushPromiseFlag(head.flag()); |
412 | let mut pad = 0; |
413 | |
414 | if head.stream_id().is_zero() { |
415 | return Err(Error::InvalidStreamId); |
416 | } |
417 | |
418 | // Read the padding length |
419 | if flags.is_padded() { |
420 | if src.is_empty() { |
421 | return Err(Error::MalformedMessage); |
422 | } |
423 | |
424 | // TODO: Ensure payload is sized correctly |
425 | pad = src[0] as usize; |
426 | |
427 | // Drop the padding |
428 | let _ = src.split_to(1); |
429 | } |
430 | |
431 | if src.len() < 5 { |
432 | return Err(Error::MalformedMessage); |
433 | } |
434 | |
435 | let (promised_id, _) = StreamId::parse(&src[..4]); |
436 | // Drop promised_id bytes |
437 | let _ = src.split_to(4); |
438 | |
439 | if pad > 0 { |
440 | if pad > src.len() { |
441 | return Err(Error::TooMuchPadding); |
442 | } |
443 | |
444 | let len = src.len() - pad; |
445 | src.truncate(len); |
446 | } |
447 | |
448 | let frame = PushPromise { |
449 | flags, |
450 | header_block: HeaderBlock { |
451 | fields: HeaderMap::new(), |
452 | field_size: 0, |
453 | is_over_size: false, |
454 | pseudo: Pseudo::default(), |
455 | }, |
456 | promised_id, |
457 | stream_id: head.stream_id(), |
458 | }; |
459 | Ok((frame, src)) |
460 | } |
461 | |
462 | pub fn load_hpack( |
463 | &mut self, |
464 | src: &mut BytesMut, |
465 | max_header_list_size: usize, |
466 | decoder: &mut hpack::Decoder, |
467 | ) -> Result<(), Error> { |
468 | self.header_block.load(src, max_header_list_size, decoder) |
469 | } |
470 | |
471 | pub fn stream_id(&self) -> StreamId { |
472 | self.stream_id |
473 | } |
474 | |
475 | pub fn promised_id(&self) -> StreamId { |
476 | self.promised_id |
477 | } |
478 | |
479 | pub fn is_end_headers(&self) -> bool { |
480 | self.flags.is_end_headers() |
481 | } |
482 | |
483 | pub fn set_end_headers(&mut self) { |
484 | self.flags.set_end_headers(); |
485 | } |
486 | |
487 | pub fn is_over_size(&self) -> bool { |
488 | self.header_block.is_over_size |
489 | } |
490 | |
491 | pub fn encode( |
492 | self, |
493 | encoder: &mut hpack::Encoder, |
494 | dst: &mut EncodeBuf<'_>, |
495 | ) -> Option<Continuation> { |
496 | // At this point, the `is_end_headers` flag should always be set |
497 | debug_assert!(self.flags.is_end_headers()); |
498 | |
499 | let head = self.head(); |
500 | let promised_id = self.promised_id; |
501 | |
502 | self.header_block |
503 | .into_encoding(encoder) |
504 | .encode(&head, dst, |dst| { |
505 | dst.put_u32(promised_id.into()); |
506 | }) |
507 | } |
508 | |
509 | fn head(&self) -> Head { |
510 | Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) |
511 | } |
512 | |
513 | /// Consume `self`, returning the parts of the frame |
514 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
515 | (self.header_block.pseudo, self.header_block.fields) |
516 | } |
517 | } |
518 | |
519 | impl<T> From<PushPromise> for Frame<T> { |
520 | fn from(src: PushPromise) -> Self { |
521 | Frame::PushPromise(src) |
522 | } |
523 | } |
524 | |
525 | impl fmt::Debug for PushPromise { |
526 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
527 | f&mut DebugStruct<'_, '_>.debug_struct("PushPromise" ) |
528 | .field("stream_id" , &self.stream_id) |
529 | .field("promised_id" , &self.promised_id) |
530 | .field(name:"flags" , &self.flags) |
531 | // `fields` and `pseudo` purposefully not included |
532 | .finish() |
533 | } |
534 | } |
535 | |
536 | // ===== impl Continuation ===== |
537 | |
538 | impl Continuation { |
539 | fn head(&self) -> Head { |
540 | Head::new(Kind::Continuation, END_HEADERS, self.stream_id) |
541 | } |
542 | |
543 | pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> { |
544 | // Get the CONTINUATION frame head |
545 | let head: Head = self.head(); |
546 | |
547 | self.header_block.encode(&head, dst, |_| {}) |
548 | } |
549 | } |
550 | |
551 | // ===== impl Pseudo ===== |
552 | |
553 | impl Pseudo { |
554 | pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { |
555 | let parts = uri::Parts::from(uri); |
556 | |
557 | let mut path = parts |
558 | .path_and_query |
559 | .map(|v| BytesStr::from(v.as_str())) |
560 | .unwrap_or(BytesStr::from_static("" )); |
561 | |
562 | match method { |
563 | Method::OPTIONS | Method::CONNECT => {} |
564 | _ if path.is_empty() => { |
565 | path = BytesStr::from_static("/" ); |
566 | } |
567 | _ => {} |
568 | } |
569 | |
570 | let mut pseudo = Pseudo { |
571 | method: Some(method), |
572 | scheme: None, |
573 | authority: None, |
574 | path: Some(path).filter(|p| !p.is_empty()), |
575 | protocol, |
576 | status: None, |
577 | }; |
578 | |
579 | // If the URI includes a scheme component, add it to the pseudo headers |
580 | // |
581 | // TODO: Scheme must be set... |
582 | if let Some(scheme) = parts.scheme { |
583 | pseudo.set_scheme(scheme); |
584 | } |
585 | |
586 | // If the URI includes an authority component, add it to the pseudo |
587 | // headers |
588 | if let Some(authority) = parts.authority { |
589 | pseudo.set_authority(BytesStr::from(authority.as_str())); |
590 | } |
591 | |
592 | pseudo |
593 | } |
594 | |
595 | pub fn response(status: StatusCode) -> Self { |
596 | Pseudo { |
597 | method: None, |
598 | scheme: None, |
599 | authority: None, |
600 | path: None, |
601 | protocol: None, |
602 | status: Some(status), |
603 | } |
604 | } |
605 | |
606 | #[cfg (feature = "unstable" )] |
607 | pub fn set_status(&mut self, value: StatusCode) { |
608 | self.status = Some(value); |
609 | } |
610 | |
611 | pub fn set_scheme(&mut self, scheme: uri::Scheme) { |
612 | let bytes_str = match scheme.as_str() { |
613 | "http" => BytesStr::from_static("http" ), |
614 | "https" => BytesStr::from_static("https" ), |
615 | s => BytesStr::from(s), |
616 | }; |
617 | self.scheme = Some(bytes_str); |
618 | } |
619 | |
620 | #[cfg (feature = "unstable" )] |
621 | pub fn set_protocol(&mut self, protocol: Protocol) { |
622 | self.protocol = Some(protocol); |
623 | } |
624 | |
625 | pub fn set_authority(&mut self, authority: BytesStr) { |
626 | self.authority = Some(authority); |
627 | } |
628 | |
629 | /// Whether it has status 1xx |
630 | pub(crate) fn is_informational(&self) -> bool { |
631 | self.status |
632 | .map_or(false, |status| status.is_informational()) |
633 | } |
634 | } |
635 | |
636 | // ===== impl EncodingHeaderBlock ===== |
637 | |
638 | impl EncodingHeaderBlock { |
639 | fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> |
640 | where |
641 | F: FnOnce(&mut EncodeBuf<'_>), |
642 | { |
643 | let head_pos = dst.get_ref().len(); |
644 | |
645 | // At this point, we don't know how big the h2 frame will be. |
646 | // So, we write the head with length 0, then write the body, and |
647 | // finally write the length once we know the size. |
648 | head.encode(0, dst); |
649 | |
650 | let payload_pos = dst.get_ref().len(); |
651 | |
652 | f(dst); |
653 | |
654 | // Now, encode the header payload |
655 | let continuation = if self.hpack.len() > dst.remaining_mut() { |
656 | dst.put_slice(&self.hpack.split_to(dst.remaining_mut())); |
657 | |
658 | Some(Continuation { |
659 | stream_id: head.stream_id(), |
660 | header_block: self, |
661 | }) |
662 | } else { |
663 | dst.put_slice(&self.hpack); |
664 | |
665 | None |
666 | }; |
667 | |
668 | // Compute the header block length |
669 | let payload_len = (dst.get_ref().len() - payload_pos) as u64; |
670 | |
671 | // Write the frame length |
672 | let payload_len_be = payload_len.to_be_bytes(); |
673 | assert!(payload_len_be[0..5].iter().all(|b| *b == 0)); |
674 | (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]); |
675 | |
676 | if continuation.is_some() { |
677 | // There will be continuation frames, so the `is_end_headers` flag |
678 | // must be unset |
679 | debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS); |
680 | |
681 | dst.get_mut()[head_pos + 4] -= END_HEADERS; |
682 | } |
683 | |
684 | continuation |
685 | } |
686 | } |
687 | |
688 | // ===== impl Iter ===== |
689 | |
690 | impl Iterator for Iter { |
691 | type Item = hpack::Header<Option<HeaderName>>; |
692 | |
693 | fn next(&mut self) -> Option<Self::Item> { |
694 | use crate::hpack::Header::*; |
695 | |
696 | if let Some(ref mut pseudo) = self.pseudo { |
697 | if let Some(method) = pseudo.method.take() { |
698 | return Some(Method(method)); |
699 | } |
700 | |
701 | if let Some(scheme) = pseudo.scheme.take() { |
702 | return Some(Scheme(scheme)); |
703 | } |
704 | |
705 | if let Some(authority) = pseudo.authority.take() { |
706 | return Some(Authority(authority)); |
707 | } |
708 | |
709 | if let Some(path) = pseudo.path.take() { |
710 | return Some(Path(path)); |
711 | } |
712 | |
713 | if let Some(protocol) = pseudo.protocol.take() { |
714 | return Some(Protocol(protocol)); |
715 | } |
716 | |
717 | if let Some(status) = pseudo.status.take() { |
718 | return Some(Status(status)); |
719 | } |
720 | } |
721 | |
722 | self.pseudo = None; |
723 | |
724 | self.fields |
725 | .next() |
726 | .map(|(name, value)| Field { name, value }) |
727 | } |
728 | } |
729 | |
730 | // ===== impl HeadersFlag ===== |
731 | |
732 | impl HeadersFlag { |
733 | pub fn empty() -> HeadersFlag { |
734 | HeadersFlag(0) |
735 | } |
736 | |
737 | pub fn load(bits: u8) -> HeadersFlag { |
738 | HeadersFlag(bits & ALL) |
739 | } |
740 | |
741 | pub fn is_end_stream(&self) -> bool { |
742 | self.0 & END_STREAM == END_STREAM |
743 | } |
744 | |
745 | pub fn set_end_stream(&mut self) { |
746 | self.0 |= END_STREAM; |
747 | } |
748 | |
749 | pub fn is_end_headers(&self) -> bool { |
750 | self.0 & END_HEADERS == END_HEADERS |
751 | } |
752 | |
753 | pub fn set_end_headers(&mut self) { |
754 | self.0 |= END_HEADERS; |
755 | } |
756 | |
757 | pub fn is_padded(&self) -> bool { |
758 | self.0 & PADDED == PADDED |
759 | } |
760 | |
761 | pub fn is_priority(&self) -> bool { |
762 | self.0 & PRIORITY == PRIORITY |
763 | } |
764 | } |
765 | |
766 | impl Default for HeadersFlag { |
767 | /// Returns a `HeadersFlag` value with `END_HEADERS` set. |
768 | fn default() -> Self { |
769 | HeadersFlag(END_HEADERS) |
770 | } |
771 | } |
772 | |
773 | impl From<HeadersFlag> for u8 { |
774 | fn from(src: HeadersFlag) -> u8 { |
775 | src.0 |
776 | } |
777 | } |
778 | |
779 | impl fmt::Debug for HeadersFlag { |
780 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
781 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
782 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
783 | .flag_if(self.is_end_stream(), "END_STREAM" ) |
784 | .flag_if(self.is_padded(), "PADDED" ) |
785 | .flag_if(self.is_priority(), name:"PRIORITY" ) |
786 | .finish() |
787 | } |
788 | } |
789 | |
790 | // ===== impl PushPromiseFlag ===== |
791 | |
792 | impl PushPromiseFlag { |
793 | pub fn empty() -> PushPromiseFlag { |
794 | PushPromiseFlag(0) |
795 | } |
796 | |
797 | pub fn load(bits: u8) -> PushPromiseFlag { |
798 | PushPromiseFlag(bits & ALL) |
799 | } |
800 | |
801 | pub fn is_end_headers(&self) -> bool { |
802 | self.0 & END_HEADERS == END_HEADERS |
803 | } |
804 | |
805 | pub fn set_end_headers(&mut self) { |
806 | self.0 |= END_HEADERS; |
807 | } |
808 | |
809 | pub fn is_padded(&self) -> bool { |
810 | self.0 & PADDED == PADDED |
811 | } |
812 | } |
813 | |
814 | impl Default for PushPromiseFlag { |
815 | /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. |
816 | fn default() -> Self { |
817 | PushPromiseFlag(END_HEADERS) |
818 | } |
819 | } |
820 | |
821 | impl From<PushPromiseFlag> for u8 { |
822 | fn from(src: PushPromiseFlag) -> u8 { |
823 | src.0 |
824 | } |
825 | } |
826 | |
827 | impl fmt::Debug for PushPromiseFlag { |
828 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
829 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
830 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
831 | .flag_if(self.is_padded(), name:"PADDED" ) |
832 | .finish() |
833 | } |
834 | } |
835 | |
836 | // ===== HeaderBlock ===== |
837 | |
838 | impl HeaderBlock { |
839 | fn load( |
840 | &mut self, |
841 | src: &mut BytesMut, |
842 | max_header_list_size: usize, |
843 | decoder: &mut hpack::Decoder, |
844 | ) -> Result<(), Error> { |
845 | let mut reg = !self.fields.is_empty(); |
846 | let mut malformed = false; |
847 | let mut headers_size = self.calculate_header_list_size(); |
848 | |
849 | macro_rules! set_pseudo { |
850 | ($field:ident, $val:expr) => {{ |
851 | if reg { |
852 | tracing::trace!("load_hpack; header malformed -- pseudo not at head of block" ); |
853 | malformed = true; |
854 | } else if self.pseudo.$field.is_some() { |
855 | tracing::trace!("load_hpack; header malformed -- repeated pseudo" ); |
856 | malformed = true; |
857 | } else { |
858 | let __val = $val; |
859 | headers_size += |
860 | decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); |
861 | if headers_size < max_header_list_size { |
862 | self.pseudo.$field = Some(__val); |
863 | } else if !self.is_over_size { |
864 | tracing::trace!("load_hpack; header list size over max" ); |
865 | self.is_over_size = true; |
866 | } |
867 | } |
868 | }}; |
869 | } |
870 | |
871 | let mut cursor = Cursor::new(src); |
872 | |
873 | // If the header frame is malformed, we still have to continue decoding |
874 | // the headers. A malformed header frame is a stream level error, but |
875 | // the hpack state is connection level. In order to maintain correct |
876 | // state for other streams, the hpack decoding process must complete. |
877 | let res = decoder.decode(&mut cursor, |header| { |
878 | use crate::hpack::Header::*; |
879 | |
880 | match header { |
881 | Field { name, value } => { |
882 | // Connection level header fields are not supported and must |
883 | // result in a protocol error. |
884 | |
885 | if name == header::CONNECTION |
886 | || name == header::TRANSFER_ENCODING |
887 | || name == header::UPGRADE |
888 | || name == "keep-alive" |
889 | || name == "proxy-connection" |
890 | { |
891 | tracing::trace!("load_hpack; connection level header" ); |
892 | malformed = true; |
893 | } else if name == header::TE && value != "trailers" { |
894 | tracing::trace!( |
895 | "load_hpack; TE header not set to trailers; val= {:?}" , |
896 | value |
897 | ); |
898 | malformed = true; |
899 | } else { |
900 | reg = true; |
901 | |
902 | headers_size += decoded_header_size(name.as_str().len(), value.len()); |
903 | if headers_size < max_header_list_size { |
904 | self.field_size += |
905 | decoded_header_size(name.as_str().len(), value.len()); |
906 | self.fields.append(name, value); |
907 | } else if !self.is_over_size { |
908 | tracing::trace!("load_hpack; header list size over max" ); |
909 | self.is_over_size = true; |
910 | } |
911 | } |
912 | } |
913 | Authority(v) => set_pseudo!(authority, v), |
914 | Method(v) => set_pseudo!(method, v), |
915 | Scheme(v) => set_pseudo!(scheme, v), |
916 | Path(v) => set_pseudo!(path, v), |
917 | Protocol(v) => set_pseudo!(protocol, v), |
918 | Status(v) => set_pseudo!(status, v), |
919 | } |
920 | }); |
921 | |
922 | if let Err(e) = res { |
923 | tracing::trace!("hpack decoding error; err= {:?}" , e); |
924 | return Err(e.into()); |
925 | } |
926 | |
927 | if malformed { |
928 | tracing::trace!("malformed message" ); |
929 | return Err(Error::MalformedMessage); |
930 | } |
931 | |
932 | Ok(()) |
933 | } |
934 | |
935 | fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { |
936 | let mut hpack = BytesMut::new(); |
937 | let headers = Iter { |
938 | pseudo: Some(self.pseudo), |
939 | fields: self.fields.into_iter(), |
940 | }; |
941 | |
942 | encoder.encode(headers, &mut hpack); |
943 | |
944 | EncodingHeaderBlock { |
945 | hpack: hpack.freeze(), |
946 | } |
947 | } |
948 | |
949 | /// Calculates the size of the currently decoded header list. |
950 | /// |
951 | /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE |
952 | /// |
953 | /// > The value is based on the uncompressed size of header fields, |
954 | /// > including the length of the name and value in octets plus an |
955 | /// > overhead of 32 octets for each header field. |
956 | fn calculate_header_list_size(&self) -> usize { |
957 | macro_rules! pseudo_size { |
958 | ($name:ident) => {{ |
959 | self.pseudo |
960 | .$name |
961 | .as_ref() |
962 | .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) |
963 | .unwrap_or(0) |
964 | }}; |
965 | } |
966 | |
967 | pseudo_size!(method) |
968 | + pseudo_size!(scheme) |
969 | + pseudo_size!(status) |
970 | + pseudo_size!(authority) |
971 | + pseudo_size!(path) |
972 | + self.field_size |
973 | } |
974 | } |
975 | |
976 | fn calculate_headermap_size(map: &HeaderMap) -> usize { |
977 | mapimpl Iterator .iter() |
978 | .map(|(name: &HeaderName, value: &HeaderValue)| decoded_header_size(name:name.as_str().len(), value:value.len())) |
979 | .sum::<usize>() |
980 | } |
981 | |
/// Size charged for one decoded header: the length of its name, the length
/// of its value and a fixed overhead of 32.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;

    PER_FIELD_OVERHEAD + name + value
}
985 | |
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use http::HeaderValue;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` headers into a buffer that is too small, so the
    /// block is split between a HEADERS frame and a CONTINUATION frame, and
    /// checks the exact bytes of both frames.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let fields = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("hello"),
                HeaderValue::from_static("world"),
            ),
            (
                HeaderName::from_static("hello"),
                HeaderValue::from_static("zomg"),
            ),
            (
                HeaderName::from_static("hello"),
                HeaderValue::from_static("sup"),
            ),
        ]);
        let headers = Headers::new(StreamId::ZERO, Default::default(), fields);

        // Only 8 payload bytes are allowed, so a continuation is expected.
        let mut first_frame = (&mut dst).limit(frame::HEADER_LEN + 8);
        let continuation = headers.encode(&mut encoder, &mut first_frame).unwrap();

        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // The first byte of the value is in this frame; the rest follows in
        // the continuation frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        let mut second_frame = (&mut dst).limit(frame::HEADER_LEN + 16);
        assert!(continuation.encode(&mut second_frame).is_none());

        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // // Next is not indexed
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src`, panicking when decoding fails.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut scratch = BytesMut::new();
        huffman::decode(src, &mut scratch).unwrap()
    }
}
1056 | |