1 | use crate::enums::ContentType; |
2 | use crate::enums::ProtocolVersion; |
3 | use crate::msgs::message::{BorrowedPlainMessage, PlainMessage}; |
4 | use crate::Error; |
/// Largest payload, in bytes, a single TLS record may carry (2^14).
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Size of the record header preceding every fragment: content type (1 byte),
/// protocol version (2 bytes) and payload length (2 bytes).
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
/// Largest whole record, header included, the fragmenter can be asked to emit.
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
8 | |
9 | pub struct MessageFragmenter { |
10 | max_frag: usize, |
11 | } |
12 | |
13 | impl Default for MessageFragmenter { |
14 | fn default() -> Self { |
15 | Self { |
16 | max_frag: MAX_FRAGMENT_LEN, |
17 | } |
18 | } |
19 | } |
20 | |
21 | impl MessageFragmenter { |
22 | /// Take the Message `msg` and re-fragment it into new |
23 | /// messages whose fragment is no more than max_frag. |
24 | /// Return an iterator across those messages. |
25 | /// Payloads are borrowed. |
26 | pub fn fragment_message<'a>( |
27 | &self, |
28 | msg: &'a PlainMessage, |
29 | ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a { |
30 | self.fragment_slice(msg.typ, msg.version, &msg.payload.0) |
31 | } |
32 | |
33 | /// Enqueue borrowed fragments of (version, typ, payload) which |
34 | /// are no longer than max_frag onto the `out` deque. |
35 | pub(crate) fn fragment_slice<'a>( |
36 | &self, |
37 | typ: ContentType, |
38 | version: ProtocolVersion, |
39 | payload: &'a [u8], |
40 | ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a { |
41 | payload |
42 | .chunks(self.max_frag) |
43 | .map(move |c| BorrowedPlainMessage { |
44 | typ, |
45 | version, |
46 | payload: c, |
47 | }) |
48 | } |
49 | |
50 | /// Set the maximum fragment size that will be produced. |
51 | /// |
52 | /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments |
53 | /// up to 10 bytes long. |
54 | /// |
55 | /// A `max_fragment_size` of `None` sets the highest allowable fragment size. |
56 | /// |
57 | /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389. |
58 | pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> { |
59 | self.max_frag = match max_fragment_size { |
60 | Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD, |
61 | None => MAX_FRAGMENT_LEN, |
62 | _ => return Err(Error::BadMaxFragmentSize), |
63 | }; |
64 | Ok(()) |
65 | } |
66 | } |
67 | |
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, MAX_FRAGMENT_LEN, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::ContentType;
    use crate::enums::ProtocolVersion;
    use crate::msgs::base::Payload;
    use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
    use crate::Error;

    /// Asserts that fragment `m` has the expected type, version and payload,
    /// and that its unencrypted encoding is `total_len` bytes long.
    fn msg_eq(
        m: &BorrowedPlainMessage,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload, bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        // Exactly 8 payload bytes, so the encoded record is PACKET_OVERHEAD + 8.
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn payload_of_exactly_max_frag_is_not_split() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (0..27u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data.clone()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(&q[0], 32, &typ, &version, &data);
    }

    #[test]
    fn empty_payload_yields_no_fragments() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(Vec::new()),
        };

        let frag = MessageFragmenter::default();
        assert_eq!(
            frag.fragment_message(&m)
                .count(),
            0
        );
    }

    #[test]
    fn set_max_fragment_size_bounds() {
        let mut frag = MessageFragmenter::default();
        assert_eq!(frag.max_frag, MAX_FRAGMENT_LEN);

        frag.set_max_fragment_size(Some(32))
            .unwrap();
        assert_eq!(frag.max_frag, 32 - PACKET_OVERHEAD);

        // Out-of-range sizes are rejected and leave the previous limit intact.
        assert!(matches!(
            frag.set_max_fragment_size(Some(31)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(matches!(
            frag.set_max_fragment_size(Some(MAX_FRAGMENT_SIZE + 1)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert_eq!(frag.max_frag, 32 - PACKET_OVERHEAD);

        frag.set_max_fragment_size(Some(MAX_FRAGMENT_SIZE))
            .unwrap();
        assert_eq!(frag.max_frag, MAX_FRAGMENT_LEN);

        frag.set_max_fragment_size(Some(32))
            .unwrap();
        frag.set_max_fragment_size(None)
            .unwrap();
        assert_eq!(frag.max_frag, MAX_FRAGMENT_LEN);
    }
}
163 | |