1use crate::client;
2use crate::enums::SignatureScheme;
3use crate::error::Error;
4use crate::limited_cache;
5use crate::msgs::handshake::CertificateChain;
6use crate::msgs::persist;
7use crate::sign;
8use crate::NamedGroup;
9
10use pki_types::ServerName;
11
12use alloc::collections::VecDeque;
13use alloc::sync::Arc;
14use core::fmt;
15use std::sync::Mutex;
16
17/// An implementer of `ClientSessionStore` which does nothing.
18#[derive(Debug)]
19pub(super) struct NoClientSessionStorage;
20
21impl client::ClientSessionStore for NoClientSessionStorage {
22 fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
23
24 fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
25 None
26 }
27
28 fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
29
30 fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
31 None
32 }
33
34 fn remove_tls12_session(&self, _: &ServerName<'_>) {}
35
36 fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
37
38 fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
39 None
40 }
41}
42
43const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
44
45struct ServerData {
46 kx_hint: Option<NamedGroup>,
47
48 // Zero or one TLS1.2 sessions.
49 #[cfg(feature = "tls12")]
50 tls12: Option<persist::Tls12ClientSessionValue>,
51
52 // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
53 tls13: VecDeque<persist::Tls13ClientSessionValue>,
54}
55
56impl Default for ServerData {
57 fn default() -> Self {
58 Self {
59 kx_hint: None,
60 #[cfg(feature = "tls12")]
61 tls12: None,
62 tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
63 }
64 }
65}
66
67/// An implementer of `ClientSessionStore` that stores everything
68/// in memory.
69///
70/// It enforces a limit on the number of entries to bound memory usage.
71pub struct ClientSessionMemoryCache {
72 servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
73}
74
75impl ClientSessionMemoryCache {
76 /// Make a new ClientSessionMemoryCache. `size` is the
77 /// maximum number of stored sessions.
78 pub fn new(size: usize) -> Self {
79 let max_servers: usize =
80 size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER;
81 Self {
82 servers: Mutex::new(limited_cache::LimitedCache::new(capacity_order_of_magnitude:max_servers)),
83 }
84 }
85}
86
87impl client::ClientSessionStore for ClientSessionMemoryCache {
88 fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
89 self.servers
90 .lock()
91 .unwrap()
92 .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
93 }
94
95 fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
96 self.servers
97 .lock()
98 .unwrap()
99 .get(server_name)
100 .and_then(|sd| sd.kx_hint)
101 }
102
103 fn set_tls12_session(
104 &self,
105 _server_name: ServerName<'static>,
106 _value: persist::Tls12ClientSessionValue,
107 ) {
108 #[cfg(feature = "tls12")]
109 self.servers
110 .lock()
111 .unwrap()
112 .get_or_insert_default_and_edit(_server_name.clone(), |data| data.tls12 = Some(_value));
113 }
114
115 fn tls12_session(
116 &self,
117 _server_name: &ServerName<'_>,
118 ) -> Option<persist::Tls12ClientSessionValue> {
119 #[cfg(not(feature = "tls12"))]
120 return None;
121
122 #[cfg(feature = "tls12")]
123 self.servers
124 .lock()
125 .unwrap()
126 .get(_server_name)
127 .and_then(|sd| sd.tls12.as_ref().cloned())
128 }
129
130 fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {
131 #[cfg(feature = "tls12")]
132 self.servers
133 .lock()
134 .unwrap()
135 .get_mut(_server_name)
136 .and_then(|data| data.tls12.take());
137 }
138
139 fn insert_tls13_ticket(
140 &self,
141 server_name: ServerName<'static>,
142 value: persist::Tls13ClientSessionValue,
143 ) {
144 self.servers
145 .lock()
146 .unwrap()
147 .get_or_insert_default_and_edit(server_name.clone(), |data| {
148 if data.tls13.len() == data.tls13.capacity() {
149 data.tls13.pop_front();
150 }
151 data.tls13.push_back(value);
152 });
153 }
154
155 fn take_tls13_ticket(
156 &self,
157 server_name: &ServerName<'static>,
158 ) -> Option<persist::Tls13ClientSessionValue> {
159 self.servers
160 .lock()
161 .unwrap()
162 .get_mut(server_name)
163 .and_then(|data| data.tls13.pop_back())
164 }
165}
166
167impl fmt::Debug for ClientSessionMemoryCache {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 // Note: we omit self.servers as it may contain sensitive data.
170 fDebugStruct<'_, '_>.debug_struct(name:"ClientSessionMemoryCache")
171 .finish()
172 }
173}
174
175#[derive(Debug)]
176pub(super) struct FailResolveClientCert {}
177
178impl client::ResolvesClientCert for FailResolveClientCert {
179 fn resolve(
180 &self,
181 _root_hint_subjects: &[&[u8]],
182 _sigschemes: &[SignatureScheme],
183 ) -> Option<Arc<sign::CertifiedKey>> {
184 None
185 }
186
187 fn has_certs(&self) -> bool {
188 false
189 }
190}
191
192#[derive(Debug)]
193pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
194
195impl AlwaysResolvesClientCert {
196 pub(super) fn new(
197 private_key: Arc<dyn sign::SigningKey>,
198 chain: CertificateChain,
199 ) -> Result<Self, Error> {
200 Ok(Self(Arc::new(data:sign::CertifiedKey::new(
201 cert:chain.0,
202 private_key,
203 ))))
204 }
205}
206
207impl client::ResolvesClientCert for AlwaysResolvesClientCert {
208 fn resolve(
209 &self,
210 _root_hint_subjects: &[&[u8]],
211 _sigschemes: &[SignatureScheme],
212 ) -> Option<Arc<sign::CertifiedKey>> {
213 Some(Arc::clone(&self.0))
214 }
215
216 fn has_certs(&self) -> bool {
217 true
218 }
219}
220
// Unit tests for this module. They are only compiled in test builds with a
// crypto provider feature (`ring` or `aws_lc_rs`) enabled.
#[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
mod tests {
    use super::NoClientSessionStorage;
    use crate::client::ClientSessionStore;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::CertificateChain;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;
    use crate::test_provider::cipher_suite;

    use pki_types::{ServerName, UnixTime};

    // Each setter on `NoClientSessionStorage` is called, and the matching getter
    // must still return nothing afterwards.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        let name = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();

        // Key-exchange hint: written, then must read back as `None`.
        c.set_kx_hint(name.clone(), NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            // Extract the TLS1.2 suite needed to build a session value.
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            c.set_tls12_session(
                name.clone(),
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Vec::new(),
                    &[],
                    CertificateChain::default(),
                    now,
                    0,
                    true,
                ),
            );
            // The stored TLS1.2 session must not be retrievable.
            assert!(c.tls12_session(&name).is_none());
            // Removal is exercised too; it must complete without panicking.
            c.remove_tls12_session(&name);
        }

        // With `tls12` enabled the wildcard arm is needed. Without it the `Tls13`
        // pattern is presumably the only variant, so clippy's
        // infallible-destructuring lint is silenced for that configuration.
        #[cfg_attr(not(feature = "tls12"), allow(clippy::infallible_destructuring_match))]
        let tls13_suite = match cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };
        c.insert_tls13_ticket(
            name.clone(),
            Tls13ClientSessionValue::new(
                tls13_suite,
                Vec::new(),
                &[],
                CertificateChain::default(),
                now,
                0,
                0,
                0,
            ),
        );
        // The inserted TLS1.3 ticket must not be retrievable.
        assert!(c.take_tls13_ticket(&name).is_none());
    }
}
292