1 | use crate::client; |
2 | use crate::enums::SignatureScheme; |
3 | use crate::error::Error; |
4 | use crate::limited_cache; |
5 | use crate::msgs::handshake::CertificateChain; |
6 | use crate::msgs::persist; |
7 | use crate::sign; |
8 | use crate::NamedGroup; |
9 | |
10 | use pki_types::ServerName; |
11 | |
12 | use alloc::collections::VecDeque; |
13 | use alloc::sync::Arc; |
14 | use core::fmt; |
15 | use std::sync::Mutex; |
16 | |
17 | /// An implementer of `ClientSessionStore` which does nothing. |
18 | #[derive (Debug)] |
19 | pub(super) struct NoClientSessionStorage; |
20 | |
21 | impl client::ClientSessionStore for NoClientSessionStorage { |
22 | fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {} |
23 | |
24 | fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> { |
25 | None |
26 | } |
27 | |
28 | fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {} |
29 | |
30 | fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> { |
31 | None |
32 | } |
33 | |
34 | fn remove_tls12_session(&self, _: &ServerName<'_>) {} |
35 | |
36 | fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {} |
37 | |
38 | fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> { |
39 | None |
40 | } |
41 | } |
42 | |
/// Intended number of TLS1.3 tickets kept per server.
///
/// Used to pre-size each server's ticket queue and to convert a
/// requested session count into a number of server entries.
const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
44 | |
45 | struct ServerData { |
46 | kx_hint: Option<NamedGroup>, |
47 | |
48 | // Zero or one TLS1.2 sessions. |
49 | #[cfg (feature = "tls12" )] |
50 | tls12: Option<persist::Tls12ClientSessionValue>, |
51 | |
52 | // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first. |
53 | tls13: VecDeque<persist::Tls13ClientSessionValue>, |
54 | } |
55 | |
56 | impl Default for ServerData { |
57 | fn default() -> Self { |
58 | Self { |
59 | kx_hint: None, |
60 | #[cfg (feature = "tls12" )] |
61 | tls12: None, |
62 | tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER), |
63 | } |
64 | } |
65 | } |
66 | |
67 | /// An implementer of `ClientSessionStore` that stores everything |
68 | /// in memory. |
69 | /// |
70 | /// It enforces a limit on the number of entries to bound memory usage. |
71 | pub struct ClientSessionMemoryCache { |
72 | servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>, |
73 | } |
74 | |
75 | impl ClientSessionMemoryCache { |
76 | /// Make a new ClientSessionMemoryCache. `size` is the |
77 | /// maximum number of stored sessions. |
78 | pub fn new(size: usize) -> Self { |
79 | let max_servers: usize = |
80 | size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER; |
81 | Self { |
82 | servers: Mutex::new(limited_cache::LimitedCache::new(capacity_order_of_magnitude:max_servers)), |
83 | } |
84 | } |
85 | } |
86 | |
87 | impl client::ClientSessionStore for ClientSessionMemoryCache { |
88 | fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) { |
89 | self.servers |
90 | .lock() |
91 | .unwrap() |
92 | .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group)); |
93 | } |
94 | |
95 | fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> { |
96 | self.servers |
97 | .lock() |
98 | .unwrap() |
99 | .get(server_name) |
100 | .and_then(|sd| sd.kx_hint) |
101 | } |
102 | |
103 | fn set_tls12_session( |
104 | &self, |
105 | _server_name: ServerName<'static>, |
106 | _value: persist::Tls12ClientSessionValue, |
107 | ) { |
108 | #[cfg (feature = "tls12" )] |
109 | self.servers |
110 | .lock() |
111 | .unwrap() |
112 | .get_or_insert_default_and_edit(_server_name.clone(), |data| data.tls12 = Some(_value)); |
113 | } |
114 | |
115 | fn tls12_session( |
116 | &self, |
117 | _server_name: &ServerName<'_>, |
118 | ) -> Option<persist::Tls12ClientSessionValue> { |
119 | #[cfg (not(feature = "tls12" ))] |
120 | return None; |
121 | |
122 | #[cfg (feature = "tls12" )] |
123 | self.servers |
124 | .lock() |
125 | .unwrap() |
126 | .get(_server_name) |
127 | .and_then(|sd| sd.tls12.as_ref().cloned()) |
128 | } |
129 | |
130 | fn remove_tls12_session(&self, _server_name: &ServerName<'static>) { |
131 | #[cfg (feature = "tls12" )] |
132 | self.servers |
133 | .lock() |
134 | .unwrap() |
135 | .get_mut(_server_name) |
136 | .and_then(|data| data.tls12.take()); |
137 | } |
138 | |
139 | fn insert_tls13_ticket( |
140 | &self, |
141 | server_name: ServerName<'static>, |
142 | value: persist::Tls13ClientSessionValue, |
143 | ) { |
144 | self.servers |
145 | .lock() |
146 | .unwrap() |
147 | .get_or_insert_default_and_edit(server_name.clone(), |data| { |
148 | if data.tls13.len() == data.tls13.capacity() { |
149 | data.tls13.pop_front(); |
150 | } |
151 | data.tls13.push_back(value); |
152 | }); |
153 | } |
154 | |
155 | fn take_tls13_ticket( |
156 | &self, |
157 | server_name: &ServerName<'static>, |
158 | ) -> Option<persist::Tls13ClientSessionValue> { |
159 | self.servers |
160 | .lock() |
161 | .unwrap() |
162 | .get_mut(server_name) |
163 | .and_then(|data| data.tls13.pop_back()) |
164 | } |
165 | } |
166 | |
167 | impl fmt::Debug for ClientSessionMemoryCache { |
168 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
169 | // Note: we omit self.servers as it may contain sensitive data. |
170 | fDebugStruct<'_, '_>.debug_struct(name:"ClientSessionMemoryCache" ) |
171 | .finish() |
172 | } |
173 | } |
174 | |
175 | #[derive (Debug)] |
176 | pub(super) struct FailResolveClientCert {} |
177 | |
178 | impl client::ResolvesClientCert for FailResolveClientCert { |
179 | fn resolve( |
180 | &self, |
181 | _root_hint_subjects: &[&[u8]], |
182 | _sigschemes: &[SignatureScheme], |
183 | ) -> Option<Arc<sign::CertifiedKey>> { |
184 | None |
185 | } |
186 | |
187 | fn has_certs(&self) -> bool { |
188 | false |
189 | } |
190 | } |
191 | |
192 | #[derive (Debug)] |
193 | pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>); |
194 | |
195 | impl AlwaysResolvesClientCert { |
196 | pub(super) fn new( |
197 | private_key: Arc<dyn sign::SigningKey>, |
198 | chain: CertificateChain, |
199 | ) -> Result<Self, Error> { |
200 | Ok(Self(Arc::new(data:sign::CertifiedKey::new( |
201 | cert:chain.0, |
202 | private_key, |
203 | )))) |
204 | } |
205 | } |
206 | |
207 | impl client::ResolvesClientCert for AlwaysResolvesClientCert { |
208 | fn resolve( |
209 | &self, |
210 | _root_hint_subjects: &[&[u8]], |
211 | _sigschemes: &[SignatureScheme], |
212 | ) -> Option<Arc<sign::CertifiedKey>> { |
213 | Some(Arc::clone(&self.0)) |
214 | } |
215 | |
216 | fn has_certs(&self) -> bool { |
217 | true |
218 | } |
219 | } |
220 | |
#[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
mod tests {
    use super::NoClientSessionStorage;
    use crate::client::ClientSessionStore;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;
    use crate::test_provider::cipher_suite;

    use pki_types::{ServerName, UnixTime};

    /// Stores one item of each kind in a `NoClientSessionStorage` and checks
    /// that each lookup afterwards finds nothing.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let store = NoClientSessionStorage {};
        let server = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();

        // Key-exchange hint: set, then expect nothing back.
        store.set_kx_hint(server.clone(), NamedGroup::X25519);
        assert_eq!(None, store.kx_hint(&server));

        // TLS1.2 session: set, expect nothing back, then remove.
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            let session = Tls12ClientSessionValue::new(
                tls12_suite,
                SessionId::empty(),
                Vec::new(),
                &[],
                CertificateChain::default(),
                now,
                0,
                true,
            );
            store.set_tls12_session(server.clone(), session);
            assert!(store.tls12_session(&server).is_none());
            store.remove_tls12_session(&server);
        }

        // TLS1.3 ticket: insert, then expect nothing to take.
        #[cfg_attr(not(feature = "tls12"), allow(clippy::infallible_destructuring_match))]
        let tls13_suite = match cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };
        let ticket = Tls13ClientSessionValue::new(
            tls13_suite,
            Vec::new(),
            &[],
            CertificateChain::default(),
            now,
            0,
            0,
            0,
        );
        store.insert_tls13_ticket(server.clone(), ticket);
        assert!(store.take_tls13_ticket(&server).is_none());
    }
}
292 | |