1use crate::enums::ContentType;
2use crate::enums::ProtocolVersion;
3use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
4use crate::Error;
/// Largest plaintext payload a single TLS record may carry: 2^14 bytes
/// (RFC 8446, section 5.1).
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Size of the record header preceding each fragment: content type (1 byte),
/// protocol version (2 bytes) and payload length (2 bytes).
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
/// Largest encoded record (header plus payload) the fragmenter can be told to produce.
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
8
/// Splits TLS messages into fragments no larger than a configurable limit.
///
/// The limit counts payload bytes only. It starts at `MAX_FRAGMENT_LEN`
/// (see the `Default` impl) and can be lowered with `set_max_fragment_size`.
#[derive(Debug)]
pub struct MessageFragmenter {
    /// Maximum number of payload bytes per fragment, excluding the record header.
    max_frag: usize,
}
12
13impl Default for MessageFragmenter {
14 fn default() -> Self {
15 Self {
16 max_frag: MAX_FRAGMENT_LEN,
17 }
18 }
19}
20
21impl MessageFragmenter {
22 /// Take the Message `msg` and re-fragment it into new
23 /// messages whose fragment is no more than max_frag.
24 /// Return an iterator across those messages.
25 /// Payloads are borrowed.
26 pub fn fragment_message<'a>(
27 &self,
28 msg: &'a PlainMessage,
29 ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
30 self.fragment_slice(msg.typ, msg.version, &msg.payload.0)
31 }
32
33 /// Enqueue borrowed fragments of (version, typ, payload) which
34 /// are no longer than max_frag onto the `out` deque.
35 pub(crate) fn fragment_slice<'a>(
36 &self,
37 typ: ContentType,
38 version: ProtocolVersion,
39 payload: &'a [u8],
40 ) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
41 payload
42 .chunks(self.max_frag)
43 .map(move |c| BorrowedPlainMessage {
44 typ,
45 version,
46 payload: c,
47 })
48 }
49
50 /// Set the maximum fragment size that will be produced.
51 ///
52 /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments
53 /// up to 10 bytes long.
54 ///
55 /// A `max_fragment_size` of `None` sets the highest allowable fragment size.
56 ///
57 /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389.
58 pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> {
59 self.max_frag = match max_fragment_size {
60 Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
61 None => MAX_FRAGMENT_LEN,
62 _ => return Err(Error::BadMaxFragmentSize),
63 };
64 Ok(())
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, MAX_FRAGMENT_LEN, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::ContentType;
    use crate::enums::ProtocolVersion;
    use crate::msgs::base::Payload;
    use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
    use crate::Error;

    /// Assert that `m` has the given type, version and payload, and that its
    /// unencrypted encoding is exactly `total_len` bytes long.
    fn msg_eq(
        m: &BorrowedPlainMessage,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload, bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    /// Build a TLS 1.2 handshake message whose payload is `len` zero bytes.
    fn zeroed_message(len: usize) -> PlainMessage {
        PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(vec![0u8; len]),
        }
    }

    /// Payload length of every fragment `frag` produces for `m`, in order.
    fn fragment_lens(frag: &MessageFragmenter, m: &PlainMessage) -> Vec<usize> {
        frag.fragment_message(m)
            .map(|f| f.payload.len())
            .collect()
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn max_fragment_size_bounds() {
        let mut frag = MessageFragmenter::default();
        assert!(matches!(
            frag.set_max_fragment_size(Some(0)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(matches!(
            frag.set_max_fragment_size(Some(31)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(frag
            .set_max_fragment_size(Some(32))
            .is_ok());
        assert!(frag
            .set_max_fragment_size(Some(MAX_FRAGMENT_SIZE))
            .is_ok());
        assert!(matches!(
            frag.set_max_fragment_size(Some(MAX_FRAGMENT_SIZE + 1)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(frag
            .set_max_fragment_size(None)
            .is_ok());
    }

    #[test]
    fn rejected_size_keeps_previous_limit() {
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        assert!(frag
            .set_max_fragment_size(Some(31))
            .is_err());
        assert_eq!(fragment_lens(&frag, &zeroed_message(69)), vec![27, 27, 15]);
    }

    #[test]
    fn payload_boundary_around_max_frag() {
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let max_frag = 32 - PACKET_OVERHEAD;
        assert_eq!(
            fragment_lens(&frag, &zeroed_message(max_frag)),
            vec![max_frag]
        );
        assert_eq!(
            fragment_lens(&frag, &zeroed_message(max_frag + 1)),
            vec![max_frag, 1]
        );
    }

    #[test]
    fn default_uses_max_fragment_len() {
        let frag = MessageFragmenter::default();
        assert_eq!(
            fragment_lens(&frag, &zeroed_message(MAX_FRAGMENT_LEN)),
            vec![MAX_FRAGMENT_LEN]
        );
        assert_eq!(
            fragment_lens(&frag, &zeroed_message(MAX_FRAGMENT_LEN + 1)),
            vec![MAX_FRAGMENT_LEN, 1]
        );
    }

    #[test]
    fn none_restores_default_limit() {
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        frag.set_max_fragment_size(None)
            .unwrap();
        assert_eq!(
            fragment_lens(&frag, &zeroed_message(MAX_FRAGMENT_LEN + 1)),
            vec![MAX_FRAGMENT_LEN, 1]
        );
    }

    /// Pins current behaviour: an empty message produces no records at all.
    #[test]
    fn empty_payload_yields_no_fragments() {
        let frag = MessageFragmenter::default();
        assert!(fragment_lens(&frag, &zeroed_message(0)).is_empty());
    }
}
163