1use crate::crypto::cipher::{MessageDecrypter, MessageEncrypter};
2use crate::error::Error;
3use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
4
5#[cfg(feature = "logging")]
6use crate::log::trace;
7
8use alloc::boxed::Box;
9
/// Sequence number at which we start asking for the connection to be closed.
///
/// When our write sequence number reaches this value,
/// `RecordLayer::wants_close_before_encrypt` reports `true`. When the peer's
/// read sequence number reaches it, `Decrypted::want_close_before_decrypt` is
/// set. Either way this happens well before the counter could wrap.
const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;

/// Write sequence number at which `RecordLayer::encrypt_exhausted` reports
/// `true`, after which `RecordLayer::encrypt_outgoing` refuses to encrypt.
const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64;
12
/// Lifecycle of the keying material for one direction (encryption or
/// decryption) of a `RecordLayer`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirectionState {
    /// No keying material.
    Invalid,

    /// Keying material present, but not yet in use.
    Prepared,

    /// Keying material in use.
    Active,
}
24
25/// Record layer that tracks decryption and encryption keys.
26pub struct RecordLayer {
27 message_encrypter: Box<dyn MessageEncrypter>,
28 message_decrypter: Box<dyn MessageDecrypter>,
29 write_seq: u64,
30 read_seq: u64,
31 has_decrypted: bool,
32 encrypt_state: DirectionState,
33 decrypt_state: DirectionState,
34
35 // Message encrypted with other keys may be encountered, so failures
36 // should be swallowed by the caller. This struct tracks the amount
37 // of message size this is allowed for.
38 trial_decryption_len: Option<usize>,
39}
40
41impl RecordLayer {
42 /// Create new record layer with no keys.
43 pub fn new() -> Self {
44 Self {
45 message_encrypter: <dyn MessageEncrypter>::invalid(),
46 message_decrypter: <dyn MessageDecrypter>::invalid(),
47 write_seq: 0,
48 read_seq: 0,
49 has_decrypted: false,
50 encrypt_state: DirectionState::Invalid,
51 decrypt_state: DirectionState::Invalid,
52 trial_decryption_len: None,
53 }
54 }
55
56 /// Decrypt a TLS message.
57 ///
58 /// `encr` is a decoded message allegedly received from the peer.
59 /// If it can be decrypted, its decryption is returned. Otherwise,
60 /// an error is returned.
61 pub(crate) fn decrypt_incoming(
62 &mut self,
63 encr: OpaqueMessage,
64 ) -> Result<Option<Decrypted>, Error> {
65 if self.decrypt_state != DirectionState::Active {
66 return Ok(Some(Decrypted {
67 want_close_before_decrypt: false,
68 plaintext: encr.into_plain_message(),
69 }));
70 }
71
72 // Set to `true` if the peer appears to getting close to encrypting
73 // too many messages with this key.
74 //
75 // Perhaps if we send an alert well before their counter wraps, a
76 // buggy peer won't make a terrible mistake here?
77 //
78 // Note that there's no reason to refuse to decrypt: the security
79 // failure has already happened.
80 let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
81
82 let encrypted_len = encr.payload().len();
83 match self
84 .message_decrypter
85 .decrypt(encr, self.read_seq)
86 {
87 Ok(plaintext) => {
88 self.read_seq += 1;
89 if !self.has_decrypted {
90 self.has_decrypted = true;
91 }
92 Ok(Some(Decrypted {
93 want_close_before_decrypt,
94 plaintext,
95 }))
96 }
97 Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
98 trace!("Dropping undecryptable message after aborted early_data");
99 Ok(None)
100 }
101 Err(err) => Err(err),
102 }
103 }
104
105 /// Encrypt a TLS message.
106 ///
107 /// `plain` is a TLS message we'd like to send. This function
108 /// panics if the requisite keying material hasn't been established yet.
109 pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
110 debug_assert!(self.encrypt_state == DirectionState::Active);
111 assert!(!self.encrypt_exhausted());
112 let seq = self.write_seq;
113 self.write_seq += 1;
114 self.message_encrypter
115 .encrypt(plain, seq)
116 .unwrap()
117 }
118
119 /// Prepare to use the given `MessageEncrypter` for future message encryption.
120 /// It is not used until you call `start_encrypting`.
121 pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
122 self.message_encrypter = cipher;
123 self.write_seq = 0;
124 self.encrypt_state = DirectionState::Prepared;
125 }
126
127 /// Prepare to use the given `MessageDecrypter` for future message decryption.
128 /// It is not used until you call `start_decrypting`.
129 pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
130 self.message_decrypter = cipher;
131 self.read_seq = 0;
132 self.decrypt_state = DirectionState::Prepared;
133 }
134
135 /// Start using the `MessageEncrypter` previously provided to the previous
136 /// call to `prepare_message_encrypter`.
137 pub(crate) fn start_encrypting(&mut self) {
138 debug_assert!(self.encrypt_state == DirectionState::Prepared);
139 self.encrypt_state = DirectionState::Active;
140 }
141
142 /// Start using the `MessageDecrypter` previously provided to the previous
143 /// call to `prepare_message_decrypter`.
144 pub(crate) fn start_decrypting(&mut self) {
145 debug_assert!(self.decrypt_state == DirectionState::Prepared);
146 self.decrypt_state = DirectionState::Active;
147 }
148
149 /// Set and start using the given `MessageEncrypter` for future outgoing
150 /// message encryption.
151 pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
152 self.prepare_message_encrypter(cipher);
153 self.start_encrypting();
154 }
155
156 /// Set and start using the given `MessageDecrypter` for future incoming
157 /// message decryption.
158 pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
159 self.prepare_message_decrypter(cipher);
160 self.start_decrypting();
161 self.trial_decryption_len = None;
162 }
163
164 /// Set and start using the given `MessageDecrypter` for future incoming
165 /// message decryption, and enable "trial decryption" mode for when TLS1.3
166 /// 0-RTT is attempted but rejected by the server.
167 pub(crate) fn set_message_decrypter_with_trial_decryption(
168 &mut self,
169 cipher: Box<dyn MessageDecrypter>,
170 max_length: usize,
171 ) {
172 self.prepare_message_decrypter(cipher);
173 self.start_decrypting();
174 self.trial_decryption_len = Some(max_length);
175 }
176
177 pub(crate) fn finish_trial_decryption(&mut self) {
178 self.trial_decryption_len = None;
179 }
180
181 /// Return true if we are getting close to encrypting too many
182 /// messages with our encryption key.
183 pub(crate) fn wants_close_before_encrypt(&self) -> bool {
184 self.write_seq == SEQ_SOFT_LIMIT
185 }
186
187 /// Return true if we outright refuse to do anything with the
188 /// encryption key.
189 pub(crate) fn encrypt_exhausted(&self) -> bool {
190 self.write_seq >= SEQ_HARD_LIMIT
191 }
192
193 pub(crate) fn is_encrypting(&self) -> bool {
194 self.encrypt_state == DirectionState::Active
195 }
196
197 /// Return true if we have ever decrypted a message. This is used in place
198 /// of checking the read_seq since that will be reset on key updates.
199 pub(crate) fn has_decrypted(&self) -> bool {
200 self.has_decrypted
201 }
202
203 pub(crate) fn write_seq(&self) -> u64 {
204 self.write_seq
205 }
206
207 pub(crate) fn read_seq(&self) -> u64 {
208 self.read_seq
209 }
210
211 fn doing_trial_decryption(&mut self, requested: usize) -> bool {
212 match self
213 .trial_decryption_len
214 .and_then(|value| value.checked_sub(requested))
215 {
216 Some(remaining) => {
217 self.trial_decryption_len = Some(remaining);
218 true
219 }
220 _ => false,
221 }
222 }
223}
224
225/// Result of decryption.
226#[derive(Debug)]
227pub(crate) struct Decrypted {
228 /// Whether the peer appears to be getting close to encrypting too many messages with this key.
229 pub(crate) want_close_before_decrypt: bool,
230 /// The decrypted message.
231 pub(crate) plaintext: PlainMessage,
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_has_decrypted() {
        use crate::{ContentType, ProtocolVersion};

        struct PassThroughDecrypter;
        impl MessageDecrypter for PassThroughDecrypter {
            fn decrypt(&mut self, m: OpaqueMessage, _: u64) -> Result<PlainMessage, Error> {
                Ok(m.into_plain_message())
            }
        }

        // A record layer starts out invalid, having never decrypted.
        let mut record_layer = RecordLayer::new();
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Invalid
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Preparing the record layer should update the decrypt state, but shouldn't affect whether it
        // has decrypted.
        record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Prepared
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Starting decryption should update the decrypt state, but not affect whether it has decrypted.
        record_layer.start_decrypting();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Decrypting a message should update the read_seq and track that we have now performed
        // a decryption.
        let msg = OpaqueMessage::new(
            ContentType::Handshake,
            ProtocolVersion::TLSv1_2,
            vec![0xC0, 0xFF, 0xEE],
        );
        record_layer
            .decrypt_incoming(msg)
            .unwrap();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 1);
        assert!(record_layer.has_decrypted());

        // Resetting the record layer message decrypter (as if a key update occurred) should reset
        // the read_seq number, but not our knowledge of whether we have decrypted previously.
        record_layer.set_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(record_layer.has_decrypted());
    }

    #[test]
    fn test_trial_decryption() {
        use crate::{ContentType, ProtocolVersion};

        // Rejects every record, as if it had been encrypted under another key.
        struct FailingDecrypter;
        impl MessageDecrypter for FailingDecrypter {
            fn decrypt(&mut self, _: OpaqueMessage, _: u64) -> Result<PlainMessage, Error> {
                Err(Error::DecryptError)
            }
        }

        let msg = |len: usize| {
            OpaqueMessage::new(
                ContentType::ApplicationData,
                ProtocolVersion::TLSv1_2,
                vec![0u8; len],
            )
        };

        let mut record_layer = RecordLayer::new();
        record_layer.set_message_decrypter_with_trial_decryption(Box::new(FailingDecrypter), 5);

        // A failure that fits in the budget is swallowed and charged against it.
        assert!(matches!(record_layer.decrypt_incoming(msg(3)), Ok(None)));
        assert_eq!(record_layer.trial_decryption_len, Some(2));

        // A failure larger than the remaining budget is reported, and the budget is unchanged.
        assert!(matches!(
            record_layer.decrypt_incoming(msg(3)),
            Err(Error::DecryptError)
        ));
        assert_eq!(record_layer.trial_decryption_len, Some(2));

        // Failed records never advance the read sequence or count as decryptions.
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Once trial decryption is finished, every failure is reported.
        record_layer.finish_trial_decryption();
        assert_eq!(record_layer.trial_decryption_len, None);
        assert!(matches!(
            record_layer.decrypt_incoming(msg(1)),
            Err(Error::DecryptError)
        ));
    }
}
296