1 | use super::{util, StreamDependency, StreamId}; |
2 | use crate::ext::Protocol; |
3 | use crate::frame::{Error, Frame, Head, Kind}; |
4 | use crate::hpack::{self, BytesStr}; |
5 | |
6 | use http::header::{self, HeaderName, HeaderValue}; |
7 | use http::{uri, HeaderMap, Method, Request, StatusCode, Uri}; |
8 | |
9 | use bytes::{BufMut, Bytes, BytesMut}; |
10 | |
11 | use std::fmt; |
12 | use std::io::Cursor; |
13 | |
/// Destination buffer for frame encoding: a `BytesMut` behind a byte limit.
///
/// The encoders use `remaining_mut()` on this limit to decide how much of a
/// header block fits in the current frame before spilling into a
/// CONTINUATION frame.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
/// Header frame
///
/// This could be either a request or a response, or trailers (see
/// `Headers::trailers`).
///
/// Incoming frames are built in two steps: `Headers::load` parses the frame
/// layout, and `Headers::load_hpack` decodes the header block. Outgoing frames
/// are serialized by `Headers::encode`, which may return a `Continuation` for
/// bytes that did not fit.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any (present only when the
    /// PRIORITY flag was set on a loaded frame).
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}
32 | |
/// Raw flags octet of a HEADERS frame.
///
/// The bits inspected are `END_STREAM`, `END_HEADERS`, `PADDED` and
/// `PRIORITY`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
35 | |
/// PUSH_PROMISE frame: announces, on `stream_id`, a request that the sender
/// will answer on the reserved stream `promised_id`.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}
50 | |
/// Raw flags octet of a PUSH_PROMISE frame.
///
/// Only `END_HEADERS` and `PADDED` are inspected.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
53 | |
/// The not-yet-written tail of an encoded header block.
///
/// `EncodingHeaderBlock::encode` returns this when the block did not fit into
/// the frame it was writing. Encoding it emits a CONTINUATION frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// Remaining HPACK-encoded bytes.
    header_block: EncodingHeaderBlock,
}
61 | |
// TODO: These fields shouldn't be `pub`
/// Pseudo-header fields of a header block.
///
/// Requests use `:method`, `:scheme`, `:authority`, `:path` and `:protocol`.
/// Responses use `:status`. A field is `None` when that pseudo-header is
/// absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,
}
75 | |
/// Iterator fed to the HPACK encoder by `HeaderBlock::into_encoding`.
///
/// It first yields the pseudo-headers that are present, then the regular
/// header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; set to `None` once all of them have been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
84 | |
/// Decoded (or to-be-encoded) contents of a HEADERS / PUSH_PROMISE header block.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Set to true if decoding went over the max header list size; fields
    /// past the limit are dropped rather than stored.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}
97 | |
/// A header block that has already been HPACK-encoded and is waiting to be
/// written out, possibly across several frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes that have not been written to a frame yet.
    hpack: Bytes,
}
102 | |
// Frame flag bits. `HeadersFlag` uses all four; `PushPromiseFlag` uses
// `END_HEADERS` and `PADDED`.

/// END_STREAM flag bit.
const END_STREAM: u8 = 0x1;
/// END_HEADERS flag bit; cleared by `EncodingHeaderBlock::encode` when a
/// CONTINUATION frame has to follow.
const END_HEADERS: u8 = 0x4;
/// PADDED flag bit: the payload starts with a pad-length byte.
const PADDED: u8 = 0x8;
/// PRIORITY flag bit: a 5-byte stream dependency precedes the header block.
const PRIORITY: u8 = 0x20;
/// Union of the flag bits above; `HeadersFlag::load` and
/// `PushPromiseFlag::load` mask raw flag octets with it.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
108 | |
109 | // ===== impl Headers ===== |
110 | |
111 | impl Headers { |
112 | /// Create a new HEADERS frame |
113 | pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self { |
114 | Headers { |
115 | stream_id, |
116 | stream_dep: None, |
117 | header_block: HeaderBlock { |
118 | fields, |
119 | is_over_size: false, |
120 | pseudo, |
121 | }, |
122 | flags: HeadersFlag::default(), |
123 | } |
124 | } |
125 | |
126 | pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self { |
127 | let mut flags = HeadersFlag::default(); |
128 | flags.set_end_stream(); |
129 | |
130 | Headers { |
131 | stream_id, |
132 | stream_dep: None, |
133 | header_block: HeaderBlock { |
134 | fields, |
135 | is_over_size: false, |
136 | pseudo: Pseudo::default(), |
137 | }, |
138 | flags, |
139 | } |
140 | } |
141 | |
142 | /// Loads the header frame but doesn't actually do HPACK decoding. |
143 | /// |
144 | /// HPACK decoding is done in the `load_hpack` step. |
145 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
146 | let flags = HeadersFlag(head.flag()); |
147 | let mut pad = 0; |
148 | |
149 | tracing::trace!("loading headers; flags= {:?}" , flags); |
150 | |
151 | if head.stream_id().is_zero() { |
152 | return Err(Error::InvalidStreamId); |
153 | } |
154 | |
155 | // Read the padding length |
156 | if flags.is_padded() { |
157 | if src.is_empty() { |
158 | return Err(Error::MalformedMessage); |
159 | } |
160 | pad = src[0] as usize; |
161 | |
162 | // Drop the padding |
163 | let _ = src.split_to(1); |
164 | } |
165 | |
166 | // Read the stream dependency |
167 | let stream_dep = if flags.is_priority() { |
168 | if src.len() < 5 { |
169 | return Err(Error::MalformedMessage); |
170 | } |
171 | let stream_dep = StreamDependency::load(&src[..5])?; |
172 | |
173 | if stream_dep.dependency_id() == head.stream_id() { |
174 | return Err(Error::InvalidDependencyId); |
175 | } |
176 | |
177 | // Drop the next 5 bytes |
178 | let _ = src.split_to(5); |
179 | |
180 | Some(stream_dep) |
181 | } else { |
182 | None |
183 | }; |
184 | |
185 | if pad > 0 { |
186 | if pad > src.len() { |
187 | return Err(Error::TooMuchPadding); |
188 | } |
189 | |
190 | let len = src.len() - pad; |
191 | src.truncate(len); |
192 | } |
193 | |
194 | let headers = Headers { |
195 | stream_id: head.stream_id(), |
196 | stream_dep, |
197 | header_block: HeaderBlock { |
198 | fields: HeaderMap::new(), |
199 | is_over_size: false, |
200 | pseudo: Pseudo::default(), |
201 | }, |
202 | flags, |
203 | }; |
204 | |
205 | Ok((headers, src)) |
206 | } |
207 | |
208 | pub fn load_hpack( |
209 | &mut self, |
210 | src: &mut BytesMut, |
211 | max_header_list_size: usize, |
212 | decoder: &mut hpack::Decoder, |
213 | ) -> Result<(), Error> { |
214 | self.header_block.load(src, max_header_list_size, decoder) |
215 | } |
216 | |
217 | pub fn stream_id(&self) -> StreamId { |
218 | self.stream_id |
219 | } |
220 | |
221 | pub fn is_end_headers(&self) -> bool { |
222 | self.flags.is_end_headers() |
223 | } |
224 | |
225 | pub fn set_end_headers(&mut self) { |
226 | self.flags.set_end_headers(); |
227 | } |
228 | |
229 | pub fn is_end_stream(&self) -> bool { |
230 | self.flags.is_end_stream() |
231 | } |
232 | |
233 | pub fn set_end_stream(&mut self) { |
234 | self.flags.set_end_stream() |
235 | } |
236 | |
237 | pub fn is_over_size(&self) -> bool { |
238 | self.header_block.is_over_size |
239 | } |
240 | |
241 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
242 | (self.header_block.pseudo, self.header_block.fields) |
243 | } |
244 | |
245 | #[cfg (feature = "unstable" )] |
246 | pub fn pseudo_mut(&mut self) -> &mut Pseudo { |
247 | &mut self.header_block.pseudo |
248 | } |
249 | |
250 | /// Whether it has status 1xx |
251 | pub(crate) fn is_informational(&self) -> bool { |
252 | self.header_block.pseudo.is_informational() |
253 | } |
254 | |
255 | pub fn fields(&self) -> &HeaderMap { |
256 | &self.header_block.fields |
257 | } |
258 | |
259 | pub fn into_fields(self) -> HeaderMap { |
260 | self.header_block.fields |
261 | } |
262 | |
263 | pub fn encode( |
264 | self, |
265 | encoder: &mut hpack::Encoder, |
266 | dst: &mut EncodeBuf<'_>, |
267 | ) -> Option<Continuation> { |
268 | // At this point, the `is_end_headers` flag should always be set |
269 | debug_assert!(self.flags.is_end_headers()); |
270 | |
271 | // Get the HEADERS frame head |
272 | let head = self.head(); |
273 | |
274 | self.header_block |
275 | .into_encoding(encoder) |
276 | .encode(&head, dst, |_| {}) |
277 | } |
278 | |
279 | fn head(&self) -> Head { |
280 | Head::new(Kind::Headers, self.flags.into(), self.stream_id) |
281 | } |
282 | } |
283 | |
284 | impl<T> From<Headers> for Frame<T> { |
285 | fn from(src: Headers) -> Self { |
286 | Frame::Headers(src) |
287 | } |
288 | } |
289 | |
290 | impl fmt::Debug for Headers { |
291 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
292 | let mut builder: DebugStruct<'_, '_> = f.debug_struct(name:"Headers" ); |
293 | builder |
294 | .field("stream_id" , &self.stream_id) |
295 | .field(name:"flags" , &self.flags); |
296 | |
297 | if let Some(ref protocol: &Protocol) = self.header_block.pseudo.protocol { |
298 | builder.field(name:"protocol" , value:protocol); |
299 | } |
300 | |
301 | if let Some(ref dep: &StreamDependency) = self.stream_dep { |
302 | builder.field(name:"stream_dep" , value:dep); |
303 | } |
304 | |
305 | // `fields` and `pseudo` purposefully not included |
306 | builder.finish() |
307 | } |
308 | } |
309 | |
310 | // ===== util ===== |
311 | |
/// Error returned by [`parse_u64`] when the input is not a valid unsigned
/// decimal integer that fits in a `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parse an ASCII decimal string (such as a `content-length` value) into a
/// `u64`.
///
/// Only the bytes `0`-`9` are accepted: no sign, whitespace or separators.
///
/// # Errors
///
/// Returns `ParseU64Error` in any of these cases:
/// * `src` is empty;
/// * `src` contains a non-digit byte;
/// * the value is larger than `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // An empty string is not a number: Content-Length is `1*DIGIT`.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &byte| {
        if !byte.is_ascii_digit() {
            return Err(ParseU64Error);
        }
        // Checked arithmetic detects overflow exactly. Every value up to
        // `u64::MAX` is therefore accepted, and anything larger is rejected.
        acc.checked_mul(10)
            .and_then(|scaled| scaled.checked_add(u64::from(byte - b'0')))
            .ok_or(ParseU64Error)
    })
}
334 | |
335 | // ===== impl PushPromise ===== |
336 | |
/// Reasons `PushPromise::validate_request` rejects a request as a promised
/// (pushed) request.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present and does not parse to exactly `0`.
    /// Holds the result of parsing the header value.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is not accepted as safe and cacheable (only GET and
    /// HEAD are).
    NotSafeAndCacheable,
}
342 | |
343 | impl PushPromise { |
344 | pub fn new( |
345 | stream_id: StreamId, |
346 | promised_id: StreamId, |
347 | pseudo: Pseudo, |
348 | fields: HeaderMap, |
349 | ) -> Self { |
350 | PushPromise { |
351 | flags: PushPromiseFlag::default(), |
352 | header_block: HeaderBlock { |
353 | fields, |
354 | is_over_size: false, |
355 | pseudo, |
356 | }, |
357 | promised_id, |
358 | stream_id, |
359 | } |
360 | } |
361 | |
362 | pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> { |
363 | use PushPromiseHeaderError::*; |
364 | // The spec has some requirements for promised request headers |
365 | // [https://httpwg.org/specs/rfc7540.html#PushRequests] |
366 | |
367 | // A promised request "that indicates the presence of a request body |
368 | // MUST reset the promised stream with a stream error" |
369 | if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) { |
370 | let parsed_length = parse_u64(content_length.as_bytes()); |
371 | if parsed_length != Ok(0) { |
372 | return Err(InvalidContentLength(parsed_length)); |
373 | } |
374 | } |
375 | // "The server MUST include a method in the :method pseudo-header field |
376 | // that is safe and cacheable" |
377 | if !Self::safe_and_cacheable(req.method()) { |
378 | return Err(NotSafeAndCacheable); |
379 | } |
380 | |
381 | Ok(()) |
382 | } |
383 | |
384 | fn safe_and_cacheable(method: &Method) -> bool { |
385 | // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods |
386 | // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods |
387 | method == Method::GET || method == Method::HEAD |
388 | } |
389 | |
390 | pub fn fields(&self) -> &HeaderMap { |
391 | &self.header_block.fields |
392 | } |
393 | |
394 | #[cfg (feature = "unstable" )] |
395 | pub fn into_fields(self) -> HeaderMap { |
396 | self.header_block.fields |
397 | } |
398 | |
399 | /// Loads the push promise frame but doesn't actually do HPACK decoding. |
400 | /// |
401 | /// HPACK decoding is done in the `load_hpack` step. |
402 | pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { |
403 | let flags = PushPromiseFlag(head.flag()); |
404 | let mut pad = 0; |
405 | |
406 | if head.stream_id().is_zero() { |
407 | return Err(Error::InvalidStreamId); |
408 | } |
409 | |
410 | // Read the padding length |
411 | if flags.is_padded() { |
412 | if src.is_empty() { |
413 | return Err(Error::MalformedMessage); |
414 | } |
415 | |
416 | // TODO: Ensure payload is sized correctly |
417 | pad = src[0] as usize; |
418 | |
419 | // Drop the padding |
420 | let _ = src.split_to(1); |
421 | } |
422 | |
423 | if src.len() < 5 { |
424 | return Err(Error::MalformedMessage); |
425 | } |
426 | |
427 | let (promised_id, _) = StreamId::parse(&src[..4]); |
428 | // Drop promised_id bytes |
429 | let _ = src.split_to(4); |
430 | |
431 | if pad > 0 { |
432 | if pad > src.len() { |
433 | return Err(Error::TooMuchPadding); |
434 | } |
435 | |
436 | let len = src.len() - pad; |
437 | src.truncate(len); |
438 | } |
439 | |
440 | let frame = PushPromise { |
441 | flags, |
442 | header_block: HeaderBlock { |
443 | fields: HeaderMap::new(), |
444 | is_over_size: false, |
445 | pseudo: Pseudo::default(), |
446 | }, |
447 | promised_id, |
448 | stream_id: head.stream_id(), |
449 | }; |
450 | Ok((frame, src)) |
451 | } |
452 | |
453 | pub fn load_hpack( |
454 | &mut self, |
455 | src: &mut BytesMut, |
456 | max_header_list_size: usize, |
457 | decoder: &mut hpack::Decoder, |
458 | ) -> Result<(), Error> { |
459 | self.header_block.load(src, max_header_list_size, decoder) |
460 | } |
461 | |
462 | pub fn stream_id(&self) -> StreamId { |
463 | self.stream_id |
464 | } |
465 | |
466 | pub fn promised_id(&self) -> StreamId { |
467 | self.promised_id |
468 | } |
469 | |
470 | pub fn is_end_headers(&self) -> bool { |
471 | self.flags.is_end_headers() |
472 | } |
473 | |
474 | pub fn set_end_headers(&mut self) { |
475 | self.flags.set_end_headers(); |
476 | } |
477 | |
478 | pub fn is_over_size(&self) -> bool { |
479 | self.header_block.is_over_size |
480 | } |
481 | |
482 | pub fn encode( |
483 | self, |
484 | encoder: &mut hpack::Encoder, |
485 | dst: &mut EncodeBuf<'_>, |
486 | ) -> Option<Continuation> { |
487 | // At this point, the `is_end_headers` flag should always be set |
488 | debug_assert!(self.flags.is_end_headers()); |
489 | |
490 | let head = self.head(); |
491 | let promised_id = self.promised_id; |
492 | |
493 | self.header_block |
494 | .into_encoding(encoder) |
495 | .encode(&head, dst, |dst| { |
496 | dst.put_u32(promised_id.into()); |
497 | }) |
498 | } |
499 | |
500 | fn head(&self) -> Head { |
501 | Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) |
502 | } |
503 | |
504 | /// Consume `self`, returning the parts of the frame |
505 | pub fn into_parts(self) -> (Pseudo, HeaderMap) { |
506 | (self.header_block.pseudo, self.header_block.fields) |
507 | } |
508 | } |
509 | |
510 | impl<T> From<PushPromise> for Frame<T> { |
511 | fn from(src: PushPromise) -> Self { |
512 | Frame::PushPromise(src) |
513 | } |
514 | } |
515 | |
516 | impl fmt::Debug for PushPromise { |
517 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
518 | f&mut DebugStruct<'_, '_>.debug_struct("PushPromise" ) |
519 | .field("stream_id" , &self.stream_id) |
520 | .field("promised_id" , &self.promised_id) |
521 | .field(name:"flags" , &self.flags) |
522 | // `fields` and `pseudo` purposefully not included |
523 | .finish() |
524 | } |
525 | } |
526 | |
527 | // ===== impl Continuation ===== |
528 | |
529 | impl Continuation { |
530 | fn head(&self) -> Head { |
531 | Head::new(Kind::Continuation, END_HEADERS, self.stream_id) |
532 | } |
533 | |
534 | pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> { |
535 | // Get the CONTINUATION frame head |
536 | let head: Head = self.head(); |
537 | |
538 | self.header_block.encode(&head, dst, |_| {}) |
539 | } |
540 | } |
541 | |
542 | // ===== impl Pseudo ===== |
543 | |
544 | impl Pseudo { |
545 | pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { |
546 | let parts = uri::Parts::from(uri); |
547 | |
548 | let mut path = parts |
549 | .path_and_query |
550 | .map(|v| BytesStr::from(v.as_str())) |
551 | .unwrap_or(BytesStr::from_static("" )); |
552 | |
553 | match method { |
554 | Method::OPTIONS | Method::CONNECT => {} |
555 | _ if path.is_empty() => { |
556 | path = BytesStr::from_static("/" ); |
557 | } |
558 | _ => {} |
559 | } |
560 | |
561 | let mut pseudo = Pseudo { |
562 | method: Some(method), |
563 | scheme: None, |
564 | authority: None, |
565 | path: Some(path).filter(|p| !p.is_empty()), |
566 | protocol, |
567 | status: None, |
568 | }; |
569 | |
570 | // If the URI includes a scheme component, add it to the pseudo headers |
571 | // |
572 | // TODO: Scheme must be set... |
573 | if let Some(scheme) = parts.scheme { |
574 | pseudo.set_scheme(scheme); |
575 | } |
576 | |
577 | // If the URI includes an authority component, add it to the pseudo |
578 | // headers |
579 | if let Some(authority) = parts.authority { |
580 | pseudo.set_authority(BytesStr::from(authority.as_str())); |
581 | } |
582 | |
583 | pseudo |
584 | } |
585 | |
586 | pub fn response(status: StatusCode) -> Self { |
587 | Pseudo { |
588 | method: None, |
589 | scheme: None, |
590 | authority: None, |
591 | path: None, |
592 | protocol: None, |
593 | status: Some(status), |
594 | } |
595 | } |
596 | |
597 | #[cfg (feature = "unstable" )] |
598 | pub fn set_status(&mut self, value: StatusCode) { |
599 | self.status = Some(value); |
600 | } |
601 | |
602 | pub fn set_scheme(&mut self, scheme: uri::Scheme) { |
603 | let bytes_str = match scheme.as_str() { |
604 | "http" => BytesStr::from_static("http" ), |
605 | "https" => BytesStr::from_static("https" ), |
606 | s => BytesStr::from(s), |
607 | }; |
608 | self.scheme = Some(bytes_str); |
609 | } |
610 | |
611 | #[cfg (feature = "unstable" )] |
612 | pub fn set_protocol(&mut self, protocol: Protocol) { |
613 | self.protocol = Some(protocol); |
614 | } |
615 | |
616 | pub fn set_authority(&mut self, authority: BytesStr) { |
617 | self.authority = Some(authority); |
618 | } |
619 | |
620 | /// Whether it has status 1xx |
621 | pub(crate) fn is_informational(&self) -> bool { |
622 | self.status |
623 | .map_or(false, |status| status.is_informational()) |
624 | } |
625 | } |
626 | |
627 | // ===== impl EncodingHeaderBlock ===== |
628 | |
629 | impl EncodingHeaderBlock { |
630 | fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> |
631 | where |
632 | F: FnOnce(&mut EncodeBuf<'_>), |
633 | { |
634 | let head_pos = dst.get_ref().len(); |
635 | |
636 | // At this point, we don't know how big the h2 frame will be. |
637 | // So, we write the head with length 0, then write the body, and |
638 | // finally write the length once we know the size. |
639 | head.encode(0, dst); |
640 | |
641 | let payload_pos = dst.get_ref().len(); |
642 | |
643 | f(dst); |
644 | |
645 | // Now, encode the header payload |
646 | let continuation = if self.hpack.len() > dst.remaining_mut() { |
647 | dst.put_slice(&self.hpack.split_to(dst.remaining_mut())); |
648 | |
649 | Some(Continuation { |
650 | stream_id: head.stream_id(), |
651 | header_block: self, |
652 | }) |
653 | } else { |
654 | dst.put_slice(&self.hpack); |
655 | |
656 | None |
657 | }; |
658 | |
659 | // Compute the header block length |
660 | let payload_len = (dst.get_ref().len() - payload_pos) as u64; |
661 | |
662 | // Write the frame length |
663 | let payload_len_be = payload_len.to_be_bytes(); |
664 | assert!(payload_len_be[0..5].iter().all(|b| *b == 0)); |
665 | (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]); |
666 | |
667 | if continuation.is_some() { |
668 | // There will be continuation frames, so the `is_end_headers` flag |
669 | // must be unset |
670 | debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS); |
671 | |
672 | dst.get_mut()[head_pos + 4] -= END_HEADERS; |
673 | } |
674 | |
675 | continuation |
676 | } |
677 | } |
678 | |
679 | // ===== impl Iter ===== |
680 | |
681 | impl Iterator for Iter { |
682 | type Item = hpack::Header<Option<HeaderName>>; |
683 | |
684 | fn next(&mut self) -> Option<Self::Item> { |
685 | use crate::hpack::Header::*; |
686 | |
687 | if let Some(ref mut pseudo) = self.pseudo { |
688 | if let Some(method) = pseudo.method.take() { |
689 | return Some(Method(method)); |
690 | } |
691 | |
692 | if let Some(scheme) = pseudo.scheme.take() { |
693 | return Some(Scheme(scheme)); |
694 | } |
695 | |
696 | if let Some(authority) = pseudo.authority.take() { |
697 | return Some(Authority(authority)); |
698 | } |
699 | |
700 | if let Some(path) = pseudo.path.take() { |
701 | return Some(Path(path)); |
702 | } |
703 | |
704 | if let Some(protocol) = pseudo.protocol.take() { |
705 | return Some(Protocol(protocol)); |
706 | } |
707 | |
708 | if let Some(status) = pseudo.status.take() { |
709 | return Some(Status(status)); |
710 | } |
711 | } |
712 | |
713 | self.pseudo = None; |
714 | |
715 | self.fields |
716 | .next() |
717 | .map(|(name, value)| Field { name, value }) |
718 | } |
719 | } |
720 | |
721 | // ===== impl HeadersFlag ===== |
722 | |
723 | impl HeadersFlag { |
724 | pub fn empty() -> HeadersFlag { |
725 | HeadersFlag(0) |
726 | } |
727 | |
728 | pub fn load(bits: u8) -> HeadersFlag { |
729 | HeadersFlag(bits & ALL) |
730 | } |
731 | |
732 | pub fn is_end_stream(&self) -> bool { |
733 | self.0 & END_STREAM == END_STREAM |
734 | } |
735 | |
736 | pub fn set_end_stream(&mut self) { |
737 | self.0 |= END_STREAM; |
738 | } |
739 | |
740 | pub fn is_end_headers(&self) -> bool { |
741 | self.0 & END_HEADERS == END_HEADERS |
742 | } |
743 | |
744 | pub fn set_end_headers(&mut self) { |
745 | self.0 |= END_HEADERS; |
746 | } |
747 | |
748 | pub fn is_padded(&self) -> bool { |
749 | self.0 & PADDED == PADDED |
750 | } |
751 | |
752 | pub fn is_priority(&self) -> bool { |
753 | self.0 & PRIORITY == PRIORITY |
754 | } |
755 | } |
756 | |
757 | impl Default for HeadersFlag { |
758 | /// Returns a `HeadersFlag` value with `END_HEADERS` set. |
759 | fn default() -> Self { |
760 | HeadersFlag(END_HEADERS) |
761 | } |
762 | } |
763 | |
764 | impl From<HeadersFlag> for u8 { |
765 | fn from(src: HeadersFlag) -> u8 { |
766 | src.0 |
767 | } |
768 | } |
769 | |
770 | impl fmt::Debug for HeadersFlag { |
771 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
772 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
773 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
774 | .flag_if(self.is_end_stream(), "END_STREAM" ) |
775 | .flag_if(self.is_padded(), "PADDED" ) |
776 | .flag_if(self.is_priority(), name:"PRIORITY" ) |
777 | .finish() |
778 | } |
779 | } |
780 | |
781 | // ===== impl PushPromiseFlag ===== |
782 | |
783 | impl PushPromiseFlag { |
784 | pub fn empty() -> PushPromiseFlag { |
785 | PushPromiseFlag(0) |
786 | } |
787 | |
788 | pub fn load(bits: u8) -> PushPromiseFlag { |
789 | PushPromiseFlag(bits & ALL) |
790 | } |
791 | |
792 | pub fn is_end_headers(&self) -> bool { |
793 | self.0 & END_HEADERS == END_HEADERS |
794 | } |
795 | |
796 | pub fn set_end_headers(&mut self) { |
797 | self.0 |= END_HEADERS; |
798 | } |
799 | |
800 | pub fn is_padded(&self) -> bool { |
801 | self.0 & PADDED == PADDED |
802 | } |
803 | } |
804 | |
805 | impl Default for PushPromiseFlag { |
806 | /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. |
807 | fn default() -> Self { |
808 | PushPromiseFlag(END_HEADERS) |
809 | } |
810 | } |
811 | |
812 | impl From<PushPromiseFlag> for u8 { |
813 | fn from(src: PushPromiseFlag) -> u8 { |
814 | src.0 |
815 | } |
816 | } |
817 | |
818 | impl fmt::Debug for PushPromiseFlag { |
819 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
820 | util&mut DebugFlags<'_, '_>::debug_flags(fmt, self.0) |
821 | .flag_if(self.is_end_headers(), "END_HEADERS" ) |
822 | .flag_if(self.is_padded(), name:"PADDED" ) |
823 | .finish() |
824 | } |
825 | } |
826 | |
827 | // ===== HeaderBlock ===== |
828 | |
829 | impl HeaderBlock { |
830 | fn load( |
831 | &mut self, |
832 | src: &mut BytesMut, |
833 | max_header_list_size: usize, |
834 | decoder: &mut hpack::Decoder, |
835 | ) -> Result<(), Error> { |
836 | let mut reg = !self.fields.is_empty(); |
837 | let mut malformed = false; |
838 | let mut headers_size = self.calculate_header_list_size(); |
839 | |
840 | macro_rules! set_pseudo { |
841 | ($field:ident, $val:expr) => {{ |
842 | if reg { |
843 | tracing::trace!("load_hpack; header malformed -- pseudo not at head of block" ); |
844 | malformed = true; |
845 | } else if self.pseudo.$field.is_some() { |
846 | tracing::trace!("load_hpack; header malformed -- repeated pseudo" ); |
847 | malformed = true; |
848 | } else { |
849 | let __val = $val; |
850 | headers_size += |
851 | decoded_header_size(stringify!($field).len() + 1, __val.as_str().len()); |
852 | if headers_size < max_header_list_size { |
853 | self.pseudo.$field = Some(__val); |
854 | } else if !self.is_over_size { |
855 | tracing::trace!("load_hpack; header list size over max" ); |
856 | self.is_over_size = true; |
857 | } |
858 | } |
859 | }}; |
860 | } |
861 | |
862 | let mut cursor = Cursor::new(src); |
863 | |
864 | // If the header frame is malformed, we still have to continue decoding |
865 | // the headers. A malformed header frame is a stream level error, but |
866 | // the hpack state is connection level. In order to maintain correct |
867 | // state for other streams, the hpack decoding process must complete. |
868 | let res = decoder.decode(&mut cursor, |header| { |
869 | use crate::hpack::Header::*; |
870 | |
871 | match header { |
872 | Field { name, value } => { |
873 | // Connection level header fields are not supported and must |
874 | // result in a protocol error. |
875 | |
876 | if name == header::CONNECTION |
877 | || name == header::TRANSFER_ENCODING |
878 | || name == header::UPGRADE |
879 | || name == "keep-alive" |
880 | || name == "proxy-connection" |
881 | { |
882 | tracing::trace!("load_hpack; connection level header" ); |
883 | malformed = true; |
884 | } else if name == header::TE && value != "trailers" { |
885 | tracing::trace!( |
886 | "load_hpack; TE header not set to trailers; val= {:?}" , |
887 | value |
888 | ); |
889 | malformed = true; |
890 | } else { |
891 | reg = true; |
892 | |
893 | headers_size += decoded_header_size(name.as_str().len(), value.len()); |
894 | if headers_size < max_header_list_size { |
895 | self.fields.append(name, value); |
896 | } else if !self.is_over_size { |
897 | tracing::trace!("load_hpack; header list size over max" ); |
898 | self.is_over_size = true; |
899 | } |
900 | } |
901 | } |
902 | Authority(v) => set_pseudo!(authority, v), |
903 | Method(v) => set_pseudo!(method, v), |
904 | Scheme(v) => set_pseudo!(scheme, v), |
905 | Path(v) => set_pseudo!(path, v), |
906 | Protocol(v) => set_pseudo!(protocol, v), |
907 | Status(v) => set_pseudo!(status, v), |
908 | } |
909 | }); |
910 | |
911 | if let Err(e) = res { |
912 | tracing::trace!("hpack decoding error; err= {:?}" , e); |
913 | return Err(e.into()); |
914 | } |
915 | |
916 | if malformed { |
917 | tracing::trace!("malformed message" ); |
918 | return Err(Error::MalformedMessage); |
919 | } |
920 | |
921 | Ok(()) |
922 | } |
923 | |
924 | fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { |
925 | let mut hpack = BytesMut::new(); |
926 | let headers = Iter { |
927 | pseudo: Some(self.pseudo), |
928 | fields: self.fields.into_iter(), |
929 | }; |
930 | |
931 | encoder.encode(headers, &mut hpack); |
932 | |
933 | EncodingHeaderBlock { |
934 | hpack: hpack.freeze(), |
935 | } |
936 | } |
937 | |
938 | /// Calculates the size of the currently decoded header list. |
939 | /// |
940 | /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE |
941 | /// |
942 | /// > The value is based on the uncompressed size of header fields, |
943 | /// > including the length of the name and value in octets plus an |
944 | /// > overhead of 32 octets for each header field. |
945 | fn calculate_header_list_size(&self) -> usize { |
946 | macro_rules! pseudo_size { |
947 | ($name:ident) => {{ |
948 | self.pseudo |
949 | .$name |
950 | .as_ref() |
951 | .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) |
952 | .unwrap_or(0) |
953 | }}; |
954 | } |
955 | |
956 | pseudo_size!(method) |
957 | + pseudo_size!(scheme) |
958 | + pseudo_size!(status) |
959 | + pseudo_size!(authority) |
960 | + pseudo_size!(path) |
961 | + self |
962 | .fields |
963 | .iter() |
964 | .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) |
965 | .sum::<usize>() |
966 | } |
967 | } |
968 | |
/// Header-list size contributed by one field: its name length plus value
/// length plus a fixed 32-octet per-field overhead (SETTINGS_MAX_HEADER_LIST_SIZE
/// accounting).
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    PER_FIELD_OVERHEAD + name + value
}
972 | |
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use http::HeaderValue;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Regression test for a header block split across a HEADERS frame and a
    /// CONTINUATION frame.
    ///
    /// The first frame is limited to an 8-byte payload, so the split falls
    /// inside the first field's value. The test then checks that the
    /// CONTINUATION frame carries the rest of that value, followed by the
    /// remaining fields.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        // Three values under the same name.
        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Room for the frame head plus only 8 payload bytes.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame: length 8, type 1, flags 0 (END_HEADERS cleared because
        // a continuation follows), stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // Literal-name field prefix, then a Huffman-coded (0x80) 4-byte name.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Value length prefix: Huffman-coded, 4 bytes.
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the encoded value fit into this frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // The rest of the block fits into a single CONTINUATION frame.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The value's remaining 3 bytes open the CONTINUATION payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame: length 15, type 9, flags 4 (END_HEADERS), stream 0.
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decode `src` into a fresh buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}
1043 | |