1 | use pki_types::ServerName; |
2 | |
3 | use crate::enums::SignatureScheme; |
4 | use crate::msgs::persist; |
5 | use crate::sync::Arc; |
6 | use crate::{NamedGroup, client, sign}; |
7 | |
8 | /// An implementer of `ClientSessionStore` which does nothing. |
9 | #[derive (Debug)] |
10 | pub(super) struct NoClientSessionStorage; |
11 | |
12 | impl client::ClientSessionStore for NoClientSessionStorage { |
13 | fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {} |
14 | |
15 | fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> { |
16 | None |
17 | } |
18 | |
19 | fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {} |
20 | |
21 | fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> { |
22 | None |
23 | } |
24 | |
25 | fn remove_tls12_session(&self, _: &ServerName<'_>) {} |
26 | |
27 | fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {} |
28 | |
29 | fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> { |
30 | None |
31 | } |
32 | } |
33 | |
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;

    use pki_types::ServerName;

    use crate::lock::Mutex;
    use crate::msgs::persist;
    use crate::{NamedGroup, limited_cache};

    /// Upper bound on the number of TLS1.3 tickets retained for one server.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything the cache remembers about a single server.
    struct ServerData {
        // The group most recently recorded via `set_kx_hint`, if any.
        kx_hint: Option<NamedGroup>,

        // Zero or one TLS1.2 sessions.
        #[cfg(feature = "tls12")]
        tls12: Option<persist::Tls12ClientSessionValue>,

        // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                #[cfg(feature = "tls12")]
                tls12: None,
                // Pre-size for the per-server limit. This is only an allocation
                // hint: `with_capacity` may reserve more, so the limit itself is
                // enforced explicitly in `insert_tls13_ticket`.
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// An implementer of `ClientSessionStore` that stores everything
    /// in memory.
    ///
    /// It enforces a limit on the number of entries to bound memory usage.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    impl ClientSessionMemoryCache {
        /// Make a new ClientSessionMemoryCache. `size` is the
        /// maximum number of stored sessions.
        ///
        /// `size` is converted into a limit on the number of distinct servers
        /// tracked. It is rounded up to whole servers of up to 8 TLS1.3 tickets
        /// each, and each server may also hold one TLS1.2 session. So `size` is
        /// an approximate bound rather than an exact one.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Self {
            let max_servers = Self::max_servers_for(size);
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
            }
        }

        /// Make a new ClientSessionMemoryCache. `size` is the
        /// maximum number of stored sessions.
        ///
        /// `size` is converted into a limit on the number of distinct servers
        /// tracked. It is rounded up to whole servers of up to 8 TLS1.3 tickets
        /// each, and each server may also hold one TLS1.2 session. So `size` is
        /// an approximate bound rather than an exact one.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
            let max_servers = Self::max_servers_for(size);
            Self {
                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers)),
            }
        }

        /// Converts a session budget into a number of servers to track:
        /// `size / MAX_TLS13_TICKETS_PER_SERVER`, rounded up.
        ///
        /// Uses saturating addition so sizes near `usize::MAX` do not overflow.
        fn max_servers_for(size: usize) -> usize {
            size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        /// Records `group` as the key-exchange hint for `server_name`,
        /// creating the per-server entry if needed.
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        /// Returns the recorded key-exchange hint for `server_name`, if any.
        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        /// Stores `_value` as the only TLS1.2 session for `_server_name`,
        /// replacing any previous one. Without the `tls12` feature the value
        /// is simply dropped.
        fn set_tls12_session(
            &self,
            _server_name: ServerName<'static>,
            _value: persist::Tls12ClientSessionValue,
        ) {
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(_server_name, |data| {
                    data.tls12 = Some(_value)
                });
        }

        /// Returns a copy of the stored TLS1.2 session for `_server_name`.
        /// Always `None` without the `tls12` feature.
        fn tls12_session(
            &self,
            _server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            #[cfg(not(feature = "tls12"))]
            return None;

            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get(_server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        /// Forgets the TLS1.2 session for `_server_name`, if one is stored.
        fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
            #[cfg(feature = "tls12")]
            self.servers
                .lock()
                .unwrap()
                .get_mut(_server_name)
                .and_then(|data| data.tls12.take());
        }

        /// Appends a TLS1.3 ticket for `server_name`, evicting the oldest
        /// ticket once `MAX_TLS13_TICKETS_PER_SERVER` are already held.
        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Compare against the constant, not `capacity()`:
                    // `VecDeque::with_capacity` only promises *at least* the
                    // requested capacity, so relying on it could let the queue
                    // grow past the intended per-server limit.
                    if data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        /// Removes and returns the most recently inserted TLS1.3 ticket for
        /// `server_name`, if any.
        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // Note: we omit self.servers as it may contain sensitive data.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}
191 | |
192 | #[cfg (any(feature = "std" , feature = "hashbrown" ))] |
193 | pub use cache::ClientSessionMemoryCache; |
194 | |
195 | #[derive (Debug)] |
196 | pub(super) struct FailResolveClientCert {} |
197 | |
198 | impl client::ResolvesClientCert for FailResolveClientCert { |
199 | fn resolve( |
200 | &self, |
201 | _root_hint_subjects: &[&[u8]], |
202 | _sigschemes: &[SignatureScheme], |
203 | ) -> Option<Arc<sign::CertifiedKey>> { |
204 | None |
205 | } |
206 | |
207 | fn has_certs(&self) -> bool { |
208 | false |
209 | } |
210 | } |
211 | |
212 | /// An exemplar `ResolvesClientCert` implementation that always resolves to a single |
213 | /// [RFC 7250] raw public key. |
214 | /// |
215 | /// [RFC 7250]: https://tools.ietf.org/html/rfc7250 |
216 | #[derive (Clone, Debug)] |
217 | pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>); |
218 | impl AlwaysResolvesClientRawPublicKeys { |
219 | /// Create a new `AlwaysResolvesClientRawPublicKeys` instance. |
220 | pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self { |
221 | Self(certified_key) |
222 | } |
223 | } |
224 | |
225 | impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys { |
226 | fn resolve( |
227 | &self, |
228 | _root_hint_subjects: &[&[u8]], |
229 | _sigschemes: &[SignatureScheme], |
230 | ) -> Option<Arc<sign::CertifiedKey>> { |
231 | Some(Arc::clone(&self.0)) |
232 | } |
233 | |
234 | fn only_raw_public_keys(&self) -> bool { |
235 | true |
236 | } |
237 | |
238 | /// Returns true if the resolver is ready to present an identity. |
239 | /// |
240 | /// Even though the function is called `has_certs`, it returns true |
241 | /// although only an RPK (Raw Public Key) is available, not an actual certificate. |
242 | fn has_certs(&self) -> bool { |
243 | true |
244 | } |
245 | } |
246 | |
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use std::prelude::v1::*;

    use pki_types::{ServerName, UnixTime};

    use super::NoClientSessionStorage;
    use super::provider::cipher_suite;
    use crate::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use crate::client::{ClientSessionStore, ResolvesClientCert};
    use crate::msgs::base::PayloadU16;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::pki_types::CertificateDer;
    use crate::suites::SupportedCipherSuite;
    use crate::sync::Arc;
    use crate::{DigitallySignedStruct, Error, SignatureScheme, sign};

    /// Feeds `NoClientSessionStorage` one value of every kind it accepts and
    /// checks that none of them can be read back.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let storage = NoClientSessionStorage {};
        let server = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();
        let verifier: Arc<dyn ServerCertVerifier> = Arc::new(DummyServerCertVerifier);
        let client_auth: Arc<dyn ResolvesClientCert> = Arc::new(DummyResolvesClientCert);

        // A recorded key-exchange hint is not reported back.
        storage.set_kx_hint(server.clone(), NamedGroup::X25519);
        assert_eq!(None, storage.kx_hint(&server));

        // A TLS1.2 session is not retained, and removing it is harmless.
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;

            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };
            let session = Tls12ClientSessionValue::new(
                tls12_suite,
                SessionId::empty(),
                Arc::new(PayloadU16::empty()),
                &[],
                CertificateChain::default(),
                &verifier,
                &client_auth,
                now,
                0,
                true,
            );

            storage.set_tls12_session(server.clone(), session);
            assert!(storage.tls12_session(&server).is_none());
            storage.remove_tls12_session(&server);
        }

        // A TLS1.3 ticket is never handed back out.
        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        let ticket = Tls13ClientSessionValue::new(
            tls13_suite,
            Arc::new(PayloadU16::empty()),
            &[],
            CertificateChain::default(),
            &verifier,
            &client_auth,
            now,
            0,
            0,
            0,
        );
        storage.insert_tls13_ticket(server.clone(), ticket);
        assert!(storage.take_tls13_ticket(&server).is_none());
    }

    /// Verifier that is only needed to build session values; calling any
    /// method on it is a test failure.
    #[derive(Debug)]
    struct DummyServerCertVerifier;

    impl ServerCertVerifier for DummyServerCertVerifier {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            unreachable!()
        }
    }

    /// Client-cert resolver that is only needed to build session values;
    /// calling any method on it is a test failure.
    #[derive(Debug)]
    struct DummyResolvesClientCert;

    impl ResolvesClientCert for DummyResolvesClientCert {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn resolve(
            &self,
            _root_hint_subjects: &[&[u8]],
            _sigschemes: &[SignatureScheme],
        ) -> Option<Arc<sign::CertifiedKey>> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn has_certs(&self) -> bool {
            unreachable!()
        }
    }
}
391 | |