1 | use crate::Error; |
2 | use crate::enums::{ContentType, ProtocolVersion}; |
3 | use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage}; |
/// Largest plaintext payload, in bytes, that a single fragment may carry (2^14).
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Bytes of record header placed in front of every fragment's payload:
/// content type (1 byte) + protocol version (2 bytes) + payload length (2 bytes).
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
/// Largest encoded fragment, header included (16389 bytes).
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
7 | |
/// Splits outgoing plaintext messages into bounded-size fragments.
///
/// The only state is the largest payload length (record header excluded)
/// that any produced fragment is allowed to carry.
pub struct MessageFragmenter {
    /// Upper bound on each fragment's payload length, in bytes.
    max_frag: usize,
}
11 | |
12 | impl Default for MessageFragmenter { |
13 | fn default() -> Self { |
14 | Self { |
15 | max_frag: MAX_FRAGMENT_LEN, |
16 | } |
17 | } |
18 | } |
19 | |
20 | impl MessageFragmenter { |
21 | /// Take `msg` and fragment it into new messages with the same type and version. |
22 | /// |
23 | /// Each returned message size is no more than `max_frag`. |
24 | /// |
25 | /// Return an iterator across those messages. |
26 | /// |
27 | /// Payloads are borrowed from `msg`. |
28 | pub fn fragment_message<'a>( |
29 | &self, |
30 | msg: &'a PlainMessage, |
31 | ) -> impl Iterator<Item = OutboundPlainMessage<'a>> + 'a { |
32 | self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into()) |
33 | } |
34 | |
35 | /// Take `payload` and fragment it into new messages with given type and version. |
36 | /// |
37 | /// Each returned message size is no more than `max_frag`. |
38 | /// |
39 | /// Return an iterator across those messages. |
40 | /// |
41 | /// Payloads are borrowed from `payload`. |
42 | pub(crate) fn fragment_payload<'a>( |
43 | &self, |
44 | typ: ContentType, |
45 | version: ProtocolVersion, |
46 | payload: OutboundChunks<'a>, |
47 | ) -> impl ExactSizeIterator<Item = OutboundPlainMessage<'a>> { |
48 | Chunker::new(payload, self.max_frag).map(move |payload| OutboundPlainMessage { |
49 | typ, |
50 | version, |
51 | payload, |
52 | }) |
53 | } |
54 | |
55 | /// Set the maximum fragment size that will be produced. |
56 | /// |
57 | /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments |
58 | /// up to 10 bytes long. |
59 | /// |
60 | /// A `max_fragment_size` of `None` sets the highest allowable fragment size. |
61 | /// |
62 | /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389. |
63 | pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> { |
64 | self.max_frag = match max_fragment_size { |
65 | Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD, |
66 | None => MAX_FRAGMENT_LEN, |
67 | _ => return Err(Error::BadMaxFragmentSize), |
68 | }; |
69 | Ok(()) |
70 | } |
71 | } |
72 | |
73 | /// An iterator over borrowed fragments of a payload |
74 | struct Chunker<'a> { |
75 | payload: OutboundChunks<'a>, |
76 | limit: usize, |
77 | } |
78 | |
79 | impl<'a> Chunker<'a> { |
80 | fn new(payload: OutboundChunks<'a>, limit: usize) -> Self { |
81 | Self { payload, limit } |
82 | } |
83 | } |
84 | |
85 | impl<'a> Iterator for Chunker<'a> { |
86 | type Item = OutboundChunks<'a>; |
87 | |
88 | fn next(&mut self) -> Option<Self::Item> { |
89 | if self.payload.is_empty() { |
90 | return None; |
91 | } |
92 | |
93 | let (before: OutboundChunks<'_>, after: OutboundChunks<'_>) = self.payload.split_at(self.limit); |
94 | self.payload = after; |
95 | Some(before) |
96 | } |
97 | } |
98 | |
99 | impl ExactSizeIterator for Chunker<'_> { |
100 | fn len(&self) -> usize { |
101 | (self.payload.len() + self.limit - 1) / self.limit |
102 | } |
103 | } |
104 | |
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;
    use std::vec;

    use super::{MessageFragmenter, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::{ContentType, ProtocolVersion};
    use crate::msgs::base::Payload;
    use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};

    /// Asserts that `m` has the expected type, version and payload bytes, and
    /// that its unencrypted encoding is exactly `total_len` bytes long.
    fn msg_eq(
        m: &OutboundPlainMessage<'_>,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload.to_vec(), bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        // Exactly 8 payload bytes, so the encoded record is PACKET_OVERHEAD + 8.
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn fragment_multiple_slices() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]];
        let borrowed_payload = OutboundChunks::new(&payload_owner);
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(37)) // 32 + packet overhead
            .unwrap();

        let fragments = frag
            .fragment_payload(typ, version, borrowed_payload)
            .collect::<Vec<_>>();
        assert_eq!(fragments.len(), 3);
        msg_eq(
            &fragments[0],
            37,
            &typ,
            &version,
            b"aaaaaaaabbbbbbbbbbbbcccccccccccc",
        );
        msg_eq(
            &fragments[1],
            37,
            &typ,
            &version,
            b"ccccccccccccccccccccdddddddddddd",
        );
        msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd");
    }

    #[test]
    fn max_fragment_size_bounds() {
        let mut frag = MessageFragmenter::default();
        assert!(frag.set_max_fragment_size(Some(31)).is_err());
        assert!(frag
            .set_max_fragment_size(Some(MAX_FRAGMENT_SIZE + 1))
            .is_err());
        assert!(frag.set_max_fragment_size(Some(32)).is_ok());
        assert!(frag
            .set_max_fragment_size(Some(MAX_FRAGMENT_SIZE))
            .is_ok());
        assert!(frag.set_max_fragment_size(None).is_ok());
    }

    #[test]
    fn reported_len_matches_fragment_count() {
        let data: Vec<u8> = (0..100u8).collect();
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32 + PACKET_OVERHEAD))
            .unwrap();

        // 100 bytes in 32-byte pieces: 32 + 32 + 32 + 4.
        let iter = frag.fragment_payload(m.typ, m.version, m.payload.bytes().into());
        assert_eq!(iter.len(), 4);
        assert_eq!(iter.count(), 4);
    }
}
233 | |