1use alloc::vec::Vec;
2use core::ops::Range;
3use core::slice::SliceIndex;
4use std::io;
5
6use super::base::Payload;
7use super::codec::Codec;
8use super::message::PlainMessage;
9use crate::enums::{ContentType, ProtocolVersion};
10use crate::error::{Error, InvalidMessage, PeerMisbehaved};
11use crate::msgs::codec;
12use crate::msgs::message::{MessageError, OpaqueMessage};
13use crate::record_layer::{Decrypted, RecordLayer};
14
15/// This deframer works to reconstruct TLS messages from a stream of arbitrary-sized reads.
16///
17/// It buffers incoming data into a `Vec` through `read()`, and returns messages through `pop()`.
18/// QUIC connections will call `push()` to append handshake payload data directly.
19#[derive(Default)]
20pub struct MessageDeframer {
21 /// Set if the peer is not talking TLS, but some other
22 /// protocol. The caller should abort the connection, because
23 /// the deframer cannot recover.
24 last_error: Option<Error>,
25
26 /// If we're in the middle of joining a handshake payload, this is the metadata.
27 joining_hs: Option<HandshakePayloadMeta>,
28}
29
30impl MessageDeframer {
31 /// Return any decrypted messages that the deframer has been able to parse.
32 ///
33 /// Returns an `Error` if the deframer failed to parse some message contents or if decryption
34 /// failed, `Ok(None)` if no full message is buffered or if trial decryption failed, and
35 /// `Ok(Some(_))` if a valid message was found and decrypted successfully.
36 pub fn pop(
37 &mut self,
38 record_layer: &mut RecordLayer,
39 negotiated_version: Option<ProtocolVersion>,
40 buffer: &mut DeframerSliceBuffer,
41 ) -> Result<Option<Deframed>, Error> {
42 if let Some(last_err) = self.last_error.clone() {
43 return Err(last_err);
44 } else if buffer.is_empty() {
45 return Ok(None);
46 }
47
48 // We loop over records we've received but not processed yet.
49 // For records that decrypt as `Handshake`, we keep the current state of the joined
50 // handshake message payload in `self.joining_hs`, appending to it as we see records.
51 let expected_len = loop {
52 let start = match &self.joining_hs {
53 Some(meta) => {
54 match meta.expected_len {
55 // We're joining a handshake payload, and we've seen the full payload.
56 Some(len) if len <= meta.payload.len() => break len,
57 // Not enough data, and we can't parse any more out of the buffer (QUIC).
58 _ if meta.quic => return Ok(None),
59 // Try parsing some more of the encrypted buffered data.
60 _ => meta.message.end,
61 }
62 }
63 None => 0,
64 };
65
66 // Does our `buf` contain a full message? It does if it is big enough to
67 // contain a header, and that header has a length which falls within `buf`.
68 // If so, deframe it and place the message onto the frames output queue.
69 let mut rd = codec::Reader::init(buffer.filled_get(start..));
70 let m = match OpaqueMessage::read(&mut rd) {
71 Ok(m) => m,
72 Err(msg_err) => {
73 let err_kind = match msg_err {
74 MessageError::TooShortForHeader | MessageError::TooShortForLength => {
75 return Ok(None)
76 }
77 MessageError::InvalidEmptyPayload => InvalidMessage::InvalidEmptyPayload,
78 MessageError::MessageTooLarge => InvalidMessage::MessageTooLarge,
79 MessageError::InvalidContentType => InvalidMessage::InvalidContentType,
80 MessageError::UnknownProtocolVersion => {
81 InvalidMessage::UnknownProtocolVersion
82 }
83 };
84
85 return Err(self.set_err(err_kind));
86 }
87 };
88
89 // Return CCS messages and early plaintext alerts immediately without decrypting.
90 let end = start + rd.used();
91 let version_is_tls13 = matches!(negotiated_version, Some(ProtocolVersion::TLSv1_3));
92 let allowed_plaintext = match m.typ {
93 // CCS messages are always plaintext.
94 ContentType::ChangeCipherSpec => true,
95 // Alerts are allowed to be plaintext if-and-only-if:
96 // * The negotiated protocol version is TLS 1.3. - In TLS 1.2 it is unambiguous when
97 // keying changes based on the CCS message. Only TLS 1.3 requires these heuristics.
98 // * We have not yet decrypted any messages from the peer - if we have we don't
99 // expect any plaintext.
100 // * The payload size is indicative of a plaintext alert message.
101 ContentType::Alert
102 if version_is_tls13
103 && !record_layer.has_decrypted()
104 && m.payload().len() <= 2 =>
105 {
106 true
107 }
108 // In other circumstances, we expect all messages to be encrypted.
109 _ => false,
110 };
111 if self.joining_hs.is_none() && allowed_plaintext {
112 // This is unencrypted. We check the contents later.
113 buffer.queue_discard(end);
114 return Ok(Some(Deframed {
115 want_close_before_decrypt: false,
116 aligned: true,
117 trial_decryption_finished: false,
118 message: m.into_plain_message(),
119 }));
120 }
121
122 // Decrypt the encrypted message (if necessary).
123 let msg = match record_layer.decrypt_incoming(m) {
124 Ok(Some(decrypted)) => {
125 let Decrypted {
126 want_close_before_decrypt,
127 plaintext,
128 } = decrypted;
129 debug_assert!(!want_close_before_decrypt);
130 plaintext
131 }
132 // This was rejected early data, discard it. If we currently have a handshake
133 // payload in progress, this counts as interleaved, so we error out.
134 Ok(None) if self.joining_hs.is_some() => {
135 return Err(self.set_err(
136 PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage,
137 ));
138 }
139 Ok(None) => {
140 buffer.queue_discard(end);
141 continue;
142 }
143 Err(e) => return Err(e),
144 };
145
146 if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
147 // "Handshake messages MUST NOT be interleaved with other record
148 // types. That is, if a handshake message is split over two or more
149 // records, there MUST NOT be any other records between them."
150 // https://www.rfc-editor.org/rfc/rfc8446#section-5.1
151 return Err(self.set_err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage));
152 }
153
154 // If it's not a handshake message, just return it -- no joining necessary.
155 if msg.typ != ContentType::Handshake {
156 let end = start + rd.used();
157 buffer.queue_discard(end);
158 return Ok(Some(Deframed {
159 want_close_before_decrypt: false,
160 aligned: true,
161 trial_decryption_finished: false,
162 message: msg,
163 }));
164 }
165
166 // If we don't know the payload size yet or if the payload size is larger
167 // than the currently buffered payload, we need to wait for more data.
168 match self.append_hs::<_, false>(msg.version, &msg.payload.0, end, buffer)? {
169 HandshakePayloadState::Blocked => return Ok(None),
170 HandshakePayloadState::Complete(len) => break len,
171 HandshakePayloadState::Continue => continue,
172 }
173 };
174
175 let meta = self.joining_hs.as_mut().unwrap(); // safe after calling `append_hs()`
176
177 // We can now wrap the complete handshake payload in a `PlainMessage`, to be returned.
178 let message = PlainMessage {
179 typ: ContentType::Handshake,
180 version: meta.version,
181 payload: Payload::new(
182 buffer.filled_get(meta.payload.start..meta.payload.start + expected_len),
183 ),
184 };
185
186 // But before we return, update the `joining_hs` state to skip past this payload.
187 if meta.payload.len() > expected_len {
188 // If we have another (beginning of) a handshake payload left in the buffer, update
189 // the payload start to point past the payload we're about to yield, and update the
190 // `expected_len` to match the state of that remaining payload.
191 meta.payload.start += expected_len;
192 meta.expected_len =
193 payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
194 } else {
195 // Otherwise, we've yielded the last handshake payload in the buffer, so we can
196 // discard all of the bytes that we're previously buffered as handshake data.
197 let end = meta.message.end;
198 self.joining_hs = None;
199 buffer.queue_discard(end);
200 }
201
202 Ok(Some(Deframed {
203 want_close_before_decrypt: false,
204 aligned: self.joining_hs.is_none(),
205 trial_decryption_finished: true,
206 message,
207 }))
208 }
209
210 /// Fuses this deframer's error and returns the set value.
211 ///
212 /// Any future calls to `pop` will return `err` again.
213 fn set_err(&mut self, err: impl Into<Error>) -> Error {
214 let err = err.into();
215 self.last_error = Some(err.clone());
216 err
217 }
218
219 /// Allow pushing handshake messages directly into the buffer.
220 pub(crate) fn push(
221 &mut self,
222 version: ProtocolVersion,
223 payload: &[u8],
224 buffer: &mut DeframerVecBuffer,
225 ) -> Result<(), Error> {
226 if !buffer.is_empty() && self.joining_hs.is_none() {
227 return Err(Error::General(
228 "cannot push QUIC messages into unrelated connection".into(),
229 ));
230 } else if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
231 return Err(Error::General(err.into()));
232 }
233
234 let end = buffer.len() + payload.len();
235 self.append_hs::<_, true>(version, payload, end, buffer)?;
236 Ok(())
237 }
238
239 /// Write the handshake message contents into the buffer and update the metadata.
240 ///
241 /// Returns true if a complete message is found.
242 fn append_hs<T: DeframerBuffer<QUIC>, const QUIC: bool>(
243 &mut self,
244 version: ProtocolVersion,
245 payload: &[u8],
246 end: usize,
247 buffer: &mut T,
248 ) -> Result<HandshakePayloadState, Error> {
249 let meta = match &mut self.joining_hs {
250 Some(meta) => {
251 debug_assert_eq!(meta.quic, QUIC);
252
253 // We're joining a handshake message to the previous one here.
254 // Write it into the buffer and update the metadata.
255
256 DeframerBuffer::<QUIC>::copy(buffer, payload, meta.payload.end);
257 meta.message.end = end;
258 meta.payload.end += payload.len();
259
260 // If we haven't parsed the payload size yet, try to do so now.
261 if meta.expected_len.is_none() {
262 meta.expected_len =
263 payload_size(buffer.filled_get(meta.payload.start..meta.payload.end))?;
264 }
265
266 meta
267 }
268 None => {
269 // We've found a new handshake message here.
270 // Write it into the buffer and create the metadata.
271
272 let expected_len = payload_size(payload)?;
273 DeframerBuffer::<QUIC>::copy(buffer, payload, 0);
274 self.joining_hs
275 .insert(HandshakePayloadMeta {
276 message: Range { start: 0, end },
277 payload: Range {
278 start: 0,
279 end: payload.len(),
280 },
281 version,
282 expected_len,
283 quic: QUIC,
284 })
285 }
286 };
287
288 Ok(match meta.expected_len {
289 Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
290 _ => match buffer.len() > meta.message.end {
291 true => HandshakePayloadState::Continue,
292 false => HandshakePayloadState::Blocked,
293 },
294 })
295 }
296
297 /// Read some bytes from `rd`, and add them to our internal buffer.
298 #[allow(clippy::comparison_chain)]
299 pub fn read(
300 &mut self,
301 rd: &mut dyn io::Read,
302 buffer: &mut DeframerVecBuffer,
303 ) -> io::Result<usize> {
304 if let Err(err) = buffer.prepare_read(self.joining_hs.is_some()) {
305 return Err(io::Error::new(io::ErrorKind::InvalidData, err));
306 }
307
308 // Try to do the largest reads possible. Note that if
309 // we get a message with a length field out of range here,
310 // we do a zero length read. That looks like an EOF to
311 // the next layer up, which is fine.
312 let new_bytes = rd.read(buffer.unfilled())?;
313 buffer.advance(new_bytes);
314 Ok(new_bytes)
315 }
316}
317
/// Growable buffer holding bytes received from the peer (via `MessageDeframer::read()` or,
/// for QUIC, `MessageDeframer::push()`) until they have been deframed.
#[derive(Default, Debug)]
pub struct DeframerVecBuffer {
    /// Buffer of data read from the socket, in the process of being parsed into messages.
    ///
    /// For buffer size management, check out the [`DeframerVecBuffer::prepare_read()`] method.
    buf: Vec<u8>,

    /// What size prefix of `buf` is used.
    used: usize,
}
328
329impl DeframerVecBuffer {
330 /// Borrows the initialized contents of this buffer and tracks pending discard operations via
331 /// the `discard` reference
332 pub fn borrow(&mut self) -> DeframerSliceBuffer {
333 DeframerSliceBuffer::new(&mut self.buf[..self.used])
334 }
335
336 /// Returns true if there are messages for the caller to process
337 pub fn has_pending(&self) -> bool {
338 !self.is_empty()
339 }
340
341 /// Resize the internal `buf` if necessary for reading more bytes.
342 fn prepare_read(&mut self, is_joining_hs: bool) -> Result<(), &'static str> {
343 // We allow a maximum of 64k of buffered data for handshake messages only. Enforce this
344 // by varying the maximum allowed buffer size here based on whether a prefix of a
345 // handshake payload is currently being buffered. Given that the first read of such a
346 // payload will only ever be 4k bytes, the next time we come around here we allow a
347 // larger buffer size. Once the large message and any following handshake messages in
348 // the same flight have been consumed, `pop()` will call `discard()` to reset `used`.
349 // At this point, the buffer resizing logic below should reduce the buffer size.
350 let allow_max = match is_joining_hs {
351 true => MAX_HANDSHAKE_SIZE as usize,
352 false => OpaqueMessage::MAX_WIRE_SIZE,
353 };
354
355 if self.used >= allow_max {
356 return Err("message buffer full");
357 }
358
359 // If we can and need to increase the buffer size to allow a 4k read, do so. After
360 // dealing with a large handshake message (exceeding `OpaqueMessage::MAX_WIRE_SIZE`),
361 // make sure to reduce the buffer size again (large messages should be rare).
362 // Also, reduce the buffer size if there are neither full nor partial messages in it,
363 // which usually means that the other side suspended sending data.
364 let need_capacity = Ord::min(allow_max, self.used + READ_SIZE);
365 if need_capacity > self.buf.len() {
366 self.buf.resize(need_capacity, 0);
367 } else if self.used == 0 || self.buf.len() > allow_max {
368 self.buf.resize(need_capacity, 0);
369 self.buf.shrink_to(need_capacity);
370 }
371
372 Ok(())
373 }
374
375 /// Discard `taken` bytes from the start of our buffer.
376 pub fn discard(&mut self, taken: usize) {
377 #[allow(clippy::comparison_chain)]
378 if taken < self.used {
379 /* Before:
380 * +----------+----------+----------+
381 * | taken | pending |xxxxxxxxxx|
382 * +----------+----------+----------+
383 * 0 ^ taken ^ self.used
384 *
385 * After:
386 * +----------+----------+----------+
387 * | pending |xxxxxxxxxxxxxxxxxxxxx|
388 * +----------+----------+----------+
389 * 0 ^ self.used
390 */
391
392 self.buf
393 .copy_within(taken..self.used, 0);
394 self.used -= taken;
395 } else if taken == self.used {
396 self.used = 0;
397 }
398 }
399
400 fn is_empty(&self) -> bool {
401 self.len() == 0
402 }
403
404 fn advance(&mut self, num_bytes: usize) {
405 self.used += num_bytes;
406 }
407
408 fn unfilled(&mut self) -> &mut [u8] {
409 &mut self.buf[self.used..]
410 }
411}
412
413impl FilledDeframerBuffer for DeframerVecBuffer {
414 fn filled_mut(&mut self) -> &mut [u8] {
415 &mut self.buf[..self.used]
416 }
417
418 fn filled(&self) -> &[u8] {
419 &self.buf[..self.used]
420 }
421}
422
423impl DeframerBuffer<true> for DeframerVecBuffer {
424 fn copy(&mut self, src: &[u8], at: usize) {
425 copy_into_buffer(self.unfilled(), src, at);
426 self.advance(num_bytes:src.len());
427 }
428}
429
430impl DeframerBuffer<false> for DeframerVecBuffer {
431 fn copy(&mut self, src: &[u8], at: usize) {
432 self.borrow().copy(src, at)
433 }
434}
435
/// A borrowed view of a [`DeframerVecBuffer`]'s filled bytes that records discard
/// operations instead of performing them.
pub struct DeframerSliceBuffer<'a> {
    /// Fully initialized bytes that will be deframed.
    buf: &'a mut [u8],
    /// How many bytes at the front of `buf` are to be discarded later.
    discard: usize,
}
443
444impl<'a> DeframerSliceBuffer<'a> {
445 pub fn new(buf: &'a mut [u8]) -> Self {
446 Self { buf, discard: 0 }
447 }
448
449 /// Tracks a pending discard operation of `num_bytes`
450 pub fn queue_discard(&mut self, num_bytes: usize) {
451 self.discard += num_bytes;
452 }
453
454 /// Returns the number of bytes that need to be discarded
455 pub fn pending_discard(&self) -> usize {
456 self.discard
457 }
458
459 pub fn is_empty(&self) -> bool {
460 self.len() == 0
461 }
462}
463
464impl FilledDeframerBuffer for DeframerSliceBuffer<'_> {
465 fn filled_mut(&mut self) -> &mut [u8] {
466 &mut self.buf[self.discard..]
467 }
468
469 fn filled(&self) -> &[u8] {
470 &self.buf[self.discard..]
471 }
472}
473
474impl DeframerBuffer<false> for DeframerSliceBuffer<'_> {
475 fn copy(&mut self, src: &[u8], at: usize) {
476 copy_into_buffer(self.filled_mut(), src, at)
477 }
478}
479
480trait DeframerBuffer<const QUIC: bool>: FilledDeframerBuffer {
481 /// Copies from the `src` buffer into this buffer at the requested index
482 ///
483 /// If `QUIC` is true the data will be copied into the *un*filled section of the buffer
484 ///
485 /// If `QUIC` is false the data will be copied into the filled section of the buffer
486 fn copy(&mut self, src: &[u8], at: usize);
487}
488
/// Overwrites `buf[at..at + src.len()]` with the contents of `src`.
///
/// Panics if that range does not fit inside `buf`.
fn copy_into_buffer(buf: &mut [u8], src: &[u8], at: usize) {
    let stop = at + src.len();
    let dst = &mut buf[at..stop];
    dst.copy_from_slice(src);
}
492
/// Access to the filled (already received and not yet discarded) bytes of a deframer buffer.
trait FilledDeframerBuffer {
    /// Mutable view of the filled bytes.
    fn filled_mut(&mut self) -> &mut [u8];

    /// Indexes into the filled bytes.
    ///
    /// Panics if `index` is out of range. Indexing, rather than `get().unwrap()`, makes the
    /// panic message report the offending index/range.
    fn filled_get<I>(&self, index: I) -> &I::Output
    where
        I: SliceIndex<[u8]>,
    {
        &self.filled()[index]
    }

    /// Number of filled bytes.
    fn len(&self) -> usize {
        self.filled().len()
    }

    /// Shared view of the filled bytes.
    fn filled(&self) -> &[u8];
}
509
/// Outcome of appending one record's handshake payload in `MessageDeframer::append_hs()`.
enum HandshakePayloadState {
    /// No complete payload yet and nothing further buffered: wait for more data.
    Blocked,
    /// A complete handshake payload of the given length (header included) is buffered.
    Complete(usize),
    /// No complete payload yet, but more buffered records remain to be processed.
    Continue,
}
518
519struct HandshakePayloadMeta {
520 /// The range of bytes from the deframer buffer that contains data processed so far.
521 ///
522 /// This will need to be discarded as the last of the handshake message is `pop()`ped.
523 message: Range<usize>,
524 /// The range of bytes from the deframer buffer that contains payload.
525 payload: Range<usize>,
526 /// The protocol version as found in the decrypted handshake message.
527 version: ProtocolVersion,
528 /// The expected size of the handshake payload, if available.
529 ///
530 /// If the received payload exceeds 4 bytes (the handshake payload header), we update
531 /// `expected_len` to contain the payload length as advertised (at most 16_777_215 bytes).
532 expected_len: Option<usize>,
533 /// True if this is a QUIC handshake message.
534 ///
535 /// In the case of QUIC, we get a plaintext handshake data directly from the CRYPTO stream,
536 /// so there's no need to unwrap and decrypt the outer TLS record. This is implemented
537 /// by directly calling `MessageDeframer::push()` from the connection.
538 quic: bool,
539}
540
541/// Determine the expected length of the payload as advertised in the header.
542///
543/// Returns `Err` if the advertised length is larger than what we want to accept
544/// (`MAX_HANDSHAKE_SIZE`), `Ok(None)` if the buffer is too small to contain a complete header,
545/// and `Ok(Some(len))` otherwise.
546fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
547 if buf.len() < HEADER_SIZE {
548 return Ok(None);
549 }
550
551 let (header: &[u8], _) = buf.split_at(HEADER_SIZE);
552 match codec::u24::read_bytes(&header[1..]) {
553 Ok(len: u24) if len.0 > MAX_HANDSHAKE_SIZE => Err(Error::InvalidMessage(
554 InvalidMessage::HandshakePayloadTooLarge,
555 )),
556 Ok(len: u24) => Ok(Some(HEADER_SIZE + usize::from(len))),
557 _ => Ok(None),
558 }
559}
560
561#[derive(Debug)]
562pub struct Deframed {
563 pub(crate) want_close_before_decrypt: bool,
564 pub(crate) aligned: bool,
565 pub(crate) trial_decryption_finished: bool,
566 pub message: PlainMessage,
567}
568
/// Size of a handshake message header: a 1-byte type followed by a 3-byte (u24) length.
const HEADER_SIZE: usize = 1 + 3;

/// TLS allows for handshake messages of up to 16MB. We restrict that to 64KB to limit the
/// potential for denial-of-service.
const MAX_HANDSHAKE_SIZE: u32 = 0xFFFF;

/// How many bytes of free space `DeframerVecBuffer::prepare_read()` tries to provide for each
/// read.
const READ_SIZE: usize = 4 * 1024;
577
// Unit tests for `MessageDeframer`, driven through an owned `DeframerVecBuffer`
// (see `BufferedDeframer` below).
//
// Judging by `pop_first` / `pop_second`, the fixture `deframer-test.1.bin` holds a
// handshake record and `deframer-test.2.bin` an alert record.
#[cfg(test)]
mod tests {
    use std::io;

    use crate::msgs::message::Message;

    use super::*;

    // Feeding the first record one byte at a time still yields it.
    #[test]
    fn check_incremental() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // Two records fed byte-by-byte are popped in order.
    #[test]
    fn check_incremental_2() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(d.has_pending());
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // A record delivered by a single read is popped.
    #[test]
    fn check_whole() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // Two records delivered by two separate reads are popped in order.
    #[test]
    fn check_whole_2() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), d.input_bytes(FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), d.input_bytes(SECOND_MESSAGE));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // Two records delivered by a single read are both popped.
    #[test]
    fn test_two_in_one_read() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            d.input_bytes_concat(FIRST_MESSAGE, SECOND_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // Same as above with the order of the two records swapped.
    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = BufferedDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            d.input_bytes_concat(SECOND_MESSAGE, FIRST_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_second(&mut d, &mut rl);
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // An I/O error between two partial reads is propagated, but it must neither poison the
    // deframer nor corrupt the bytes already buffered.
    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = BufferedDeframer::default();
        assert_len(3, d.input_bytes(&FIRST_MESSAGE[..3]));
        input_error(&mut d);
        assert_len(FIRST_MESSAGE.len() - 3, d.input_bytes(&FIRST_MESSAGE[3..]));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    #[test]
    fn test_invalid_contenttype_errors() {
        let mut d = BufferedDeframer::default();
        assert_len(
            INVALID_CONTENTTYPE_MESSAGE.len(),
            d.input_bytes(INVALID_CONTENTTYPE_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::InvalidContentType)
        );
    }

    #[test]
    fn test_invalid_version_errors() {
        let mut d = BufferedDeframer::default();
        assert_len(
            INVALID_VERSION_MESSAGE.len(),
            d.input_bytes(INVALID_VERSION_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::UnknownProtocolVersion)
        );
    }

    #[test]
    fn test_invalid_length_errors() {
        let mut d = BufferedDeframer::default();
        assert_len(
            INVALID_LENGTH_MESSAGE.len(),
            d.input_bytes(INVALID_LENGTH_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::MessageTooLarge)
        );
    }

    // An empty ApplicationData record is accepted and yields an empty payload.
    #[test]
    fn test_empty_applicationdata() {
        let mut d = BufferedDeframer::default();
        assert_len(
            EMPTY_APPLICATIONDATA_MESSAGE.len(),
            d.input_bytes(EMPTY_APPLICATIONDATA_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        let m = d
            .pop(&mut rl, None)
            .unwrap()
            .unwrap()
            .message;
        assert_eq!(m.typ, ContentType::ApplicationData);
        assert_eq!(m.payload.0.len(), 0);
        assert!(!d.has_pending());
        assert!(d.last_error.is_none());
    }

    // The error is fused: a second `pop()` returns it again.
    #[test]
    fn test_invalid_empty_errors() {
        let mut d = BufferedDeframer::default();
        assert_len(
            INVALID_EMPTY_MESSAGE.len(),
            d.input_bytes(INVALID_EMPTY_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
        );
        // CorruptMessage has been fused
        assert_eq!(
            d.pop(&mut rl, None).unwrap_err(),
            Error::InvalidMessage(InvalidMessage::InvalidEmptyPayload)
        );
    }

    // Each read grows the buffer by at most `READ_SIZE` (4096) bytes, and outside a joined
    // handshake the total is capped at `OpaqueMessage::MAX_WIRE_SIZE`; once full, `read()`
    // fails. Every `input_bytes` call re-reads `message` from its start.
    #[test]
    fn test_limited_buffer() {
        const PAYLOAD_LEN: usize = 16_384;
        // 5-byte record header + payload.
        let mut message = Vec::with_capacity(16_389);
        message.push(0x17); // ApplicationData
        message.extend(&[0x03, 0x04]); // ProtocolVersion
        message.extend((PAYLOAD_LEN as u16).to_be_bytes()); // payload length
        message.extend(&[0; PAYLOAD_LEN]);

        let mut d = BufferedDeframer::default();
        assert_len(4096, d.input_bytes(&message));
        assert_len(4096, d.input_bytes(&message));
        assert_len(4096, d.input_bytes(&message));
        assert_len(4096, d.input_bytes(&message));
        assert_len(
            OpaqueMessage::MAX_WIRE_SIZE - 16_384,
            d.input_bytes(&message),
        );
        assert!(d.input_bytes(&message).is_err());
    }

    // Performs one read that fails with `TimedOut` and checks the error is propagated.
    fn input_error(d: &mut BufferedDeframer) {
        let error = io::Error::from(io::ErrorKind::TimedOut);
        let mut rd = ErrorRead::new(error);
        d.read(&mut rd)
            .expect_err("error not propagated");
    }

    // Feeds `bytes` one byte per read, checking each read is accepted in full.
    fn input_whole_incremental(d: &mut BufferedDeframer, bytes: &[u8]) {
        let before = d.buffer.len();

        for i in 0..bytes.len() {
            assert_len(1, d.input_bytes(&bytes[i..i + 1]));
            assert!(d.has_pending());
        }

        assert_eq!(before + bytes.len(), d.buffer.len());
    }

    // Pops the next message and checks it is the (valid) handshake from FIRST_MESSAGE.
    fn pop_first(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
        let m = d
            .pop(rl, None)
            .unwrap()
            .unwrap()
            .message;
        assert_eq!(m.typ, ContentType::Handshake);
        Message::try_from(m).unwrap();
    }

    // Pops the next message and checks it is the (valid) alert from SECOND_MESSAGE.
    fn pop_second(d: &mut BufferedDeframer, rl: &mut RecordLayer) {
        let m = d
            .pop(rl, None)
            .unwrap()
            .unwrap()
            .message;
        assert_eq!(m.typ, ContentType::Alert);
        Message::try_from(m).unwrap();
    }

    // buffered version to ease testing
    #[derive(Default)]
    struct BufferedDeframer {
        inner: MessageDeframer,
        buffer: DeframerVecBuffer,
    }

    impl BufferedDeframer {
        // Reads from the start of `bytes` (a fresh cursor each call).
        fn input_bytes(&mut self, bytes: &[u8]) -> io::Result<usize> {
            let mut rd = io::Cursor::new(bytes);
            self.read(&mut rd)
        }

        // Reads `bytes1` followed by `bytes2` as one contiguous source.
        fn input_bytes_concat(&mut self, bytes1: &[u8], bytes2: &[u8]) -> io::Result<usize> {
            let mut bytes = vec![0u8; bytes1.len() + bytes2.len()];
            bytes[..bytes1.len()].clone_from_slice(bytes1);
            bytes[bytes1.len()..].clone_from_slice(bytes2);
            let mut rd = io::Cursor::new(&bytes);
            self.read(&mut rd)
        }

        // Pops through a borrowed slice buffer, then applies the discards it queued.
        fn pop(
            &mut self,
            record_layer: &mut RecordLayer,
            negotiated_version: Option<ProtocolVersion>,
        ) -> Result<Option<Deframed>, Error> {
            let mut deframer_buffer = self.buffer.borrow();
            let res = self
                .inner
                .pop(record_layer, negotiated_version, &mut deframer_buffer);
            let discard = deframer_buffer.pending_discard();
            self.buffer.discard(discard);
            res
        }

        fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
            self.inner.read(rd, &mut self.buffer)
        }

        fn has_pending(&self) -> bool {
            self.buffer.has_pending()
        }
    }

    // grant access to the `MessageDeframer.last_error` field
    impl core::ops::Deref for BufferedDeframer {
        type Target = MessageDeframer;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }

    // A reader that fails once with the given error. Before failing it scribbles over the
    // destination buffer, so a test can detect bytes from a failed read being treated as
    // received data.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            Self { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            for (i, b) in buf.iter_mut().enumerate() {
                *b = i as u8;
            }

            // Panics if called a second time.
            let error = self.error.take().unwrap();
            Err(error)
        }
    }

    // Asserts that a read succeeded and returned exactly `want` bytes.
    fn assert_len(want: usize, got: io::Result<usize>) {
        assert_eq!(Some(want), got.ok())
    }

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-empty-applicationdata.bin");

    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
    const INVALID_VERSION_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-version.bin");
    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");
}
923