1 | use alloc::collections::BTreeSet; |
2 | #[cfg (feature = "logging" )] |
3 | use alloc::string::String; |
4 | use alloc::vec; |
5 | use alloc::vec::Vec; |
6 | use core::ops::Deref; |
7 | use core::{fmt, iter}; |
8 | |
9 | use pki_types::{CertificateDer, DnsName}; |
10 | |
11 | #[cfg (feature = "tls12" )] |
12 | use crate::crypto::ActiveKeyExchange; |
13 | use crate::crypto::SecureRandom; |
14 | use crate::enums::{ |
15 | CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType, |
16 | ProtocolVersion, SignatureScheme, |
17 | }; |
18 | use crate::error::InvalidMessage; |
19 | #[cfg (feature = "tls12" )] |
20 | use crate::ffdhe_groups::FfdheGroup; |
21 | use crate::log::warn; |
22 | use crate::msgs::base::{Payload, PayloadU8, PayloadU16, PayloadU24}; |
23 | use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement}; |
24 | use crate::msgs::enums::{ |
25 | CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType, |
26 | ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, |
27 | NamedGroup, PSKKeyExchangeMode, ServerNameType, |
28 | }; |
29 | use crate::rand; |
30 | use crate::sync::Arc; |
31 | use crate::verify::DigitallySignedStruct; |
32 | use crate::x509::wrap_in_sequence; |
33 | |
/// Declare a newtype over a length-prefixed opaque payload type.
///
/// Several TLS message fields are nothing more than a byte string carried in a
/// `PayloadU8` or `PayloadU16`. This macro declares a tuple struct wrapping such a
/// payload and generates:
///
/// - `From<Vec<u8>>`, building the wrapper through `$payload_ty::new`,
/// - `AsRef<[u8]>`, borrowing the wrapped bytes, and
/// - `Codec`, delegating both encoding and decoding to the wrapped payload type.
///
/// Attributes written before `struct` (including doc comments) are forwarded to the
/// generated struct.
macro_rules! wrapped_payload(
    ($(#[$attr:meta])* $vis:vis struct $name:ident, $payload_ty:ident,) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $vis struct $name($payload_ty);

        impl From<Vec<u8>> for $name {
            fn from(bytes: Vec<u8>) -> Self {
                let payload = $payload_ty::new(bytes);
                Self(payload)
            }
        }

        impl AsRef<[u8]> for $name {
            fn as_ref(&self) -> &[u8] {
                let Self(payload) = self;
                payload.0.as_slice()
            }
        }

        impl Codec<'_> for $name {
            fn encode(&self, bytes: &mut Vec<u8>) {
                let Self(payload) = self;
                payload.encode(bytes);
            }

            fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
                $payload_ty::read(r).map(Self)
            }
        }
    }
);
68 | |
69 | #[derive (Clone, Copy, Eq, PartialEq)] |
70 | pub struct Random(pub(crate) [u8; 32]); |
71 | |
72 | impl fmt::Debug for Random { |
73 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
74 | super::base::hex(f, &self.0) |
75 | } |
76 | } |
77 | |
/// The fixed `Random` value identifying a HelloRetryRequest.
///
/// These bytes are the value RFC 8446 §4.1.3 specifies for the `random` field of a
/// ServerHello that is actually a HelloRetryRequest (SHA-256 of "HelloRetryRequest").
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

/// An all-zero `Random` value.
// NOTE(review): its uses are outside this chunk; purpose not visible here.
static ZERO_RANDOM: Random = Random([0u8; 32]);
84 | |
85 | impl Codec<'_> for Random { |
86 | fn encode(&self, bytes: &mut Vec<u8>) { |
87 | bytes.extend_from_slice(&self.0); |
88 | } |
89 | |
90 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
91 | let Some(bytes: &[u8]) = r.take(length:32) else { |
92 | return Err(InvalidMessage::MissingData("Random" )); |
93 | }; |
94 | |
95 | let mut opaque: [u8; 32] = [0; 32]; |
96 | opaque.clone_from_slice(src:bytes); |
97 | Ok(Self(opaque)) |
98 | } |
99 | } |
100 | |
101 | impl Random { |
102 | pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> { |
103 | let mut data: [u8; 32] = [0u8; 32]; |
104 | secure_random.fill(&mut data)?; |
105 | Ok(Self(data)) |
106 | } |
107 | } |
108 | |
109 | impl From<[u8; 32]> for Random { |
110 | #[inline ] |
111 | fn from(bytes: [u8; 32]) -> Self { |
112 | Self(bytes) |
113 | } |
114 | } |
115 | |
116 | #[derive (Copy, Clone)] |
117 | pub struct SessionId { |
118 | len: usize, |
119 | data: [u8; 32], |
120 | } |
121 | |
122 | impl fmt::Debug for SessionId { |
123 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
124 | super::base::hex(f, &self.data[..self.len]) |
125 | } |
126 | } |
127 | |
128 | impl PartialEq for SessionId { |
129 | fn eq(&self, other: &Self) -> bool { |
130 | if self.len != other.len { |
131 | return false; |
132 | } |
133 | |
134 | let mut diff: u8 = 0u8; |
135 | for i: usize in 0..self.len { |
136 | diff |= self.data[i] ^ other.data[i]; |
137 | } |
138 | |
139 | diff == 0u8 |
140 | } |
141 | } |
142 | |
143 | impl Codec<'_> for SessionId { |
144 | fn encode(&self, bytes: &mut Vec<u8>) { |
145 | debug_assert!(self.len <= 32); |
146 | bytes.push(self.len as u8); |
147 | bytes.extend_from_slice(self.as_ref()); |
148 | } |
149 | |
150 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
151 | let len: usize = u8::read(r)? as usize; |
152 | if len > 32 { |
153 | return Err(InvalidMessage::TrailingData("SessionID" )); |
154 | } |
155 | |
156 | let Some(bytes: &[u8]) = r.take(length:len) else { |
157 | return Err(InvalidMessage::MissingData("SessionID" )); |
158 | }; |
159 | |
160 | let mut out: [u8; 32] = [0u8; 32]; |
161 | out[..len].clone_from_slice(&bytes[..len]); |
162 | Ok(Self { data: out, len }) |
163 | } |
164 | } |
165 | |
166 | impl SessionId { |
167 | pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> { |
168 | let mut data: [u8; 32] = [0u8; 32]; |
169 | secure_random.fill(&mut data)?; |
170 | Ok(Self { data, len: 32 }) |
171 | } |
172 | |
173 | pub(crate) fn empty() -> Self { |
174 | Self { |
175 | data: [0u8; 32], |
176 | len: 0, |
177 | } |
178 | } |
179 | |
180 | #[cfg (feature = "tls12" )] |
181 | pub(crate) fn is_empty(&self) -> bool { |
182 | self.len == 0 |
183 | } |
184 | } |
185 | |
186 | impl AsRef<[u8]> for SessionId { |
187 | fn as_ref(&self) -> &[u8] { |
188 | &self.data[..self.len] |
189 | } |
190 | } |
191 | |
192 | #[derive (Clone, Debug, PartialEq)] |
193 | pub struct UnknownExtension { |
194 | pub(crate) typ: ExtensionType, |
195 | pub(crate) payload: Payload<'static>, |
196 | } |
197 | |
198 | impl UnknownExtension { |
199 | fn encode(&self, bytes: &mut Vec<u8>) { |
200 | self.payload.encode(bytes); |
201 | } |
202 | |
203 | fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self { |
204 | let payload: Payload<'static> = Payload::read(r).into_owned(); |
205 | Self { typ, payload } |
206 | } |
207 | } |
208 | |
209 | impl TlsListElement for ECPointFormat { |
210 | const SIZE_LEN: ListLength = ListLength::U8; |
211 | } |
212 | |
213 | impl TlsListElement for NamedGroup { |
214 | const SIZE_LEN: ListLength = ListLength::U16; |
215 | } |
216 | |
217 | impl TlsListElement for SignatureScheme { |
218 | const SIZE_LEN: ListLength = ListLength::U16; |
219 | } |
220 | |
221 | #[derive (Clone, Debug)] |
222 | pub(crate) enum ServerNamePayload { |
223 | HostName(DnsName<'static>), |
224 | IpAddress(PayloadU16), |
225 | Unknown(Payload<'static>), |
226 | } |
227 | |
228 | impl ServerNamePayload { |
229 | pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self { |
230 | Self::HostName(hostname) |
231 | } |
232 | |
233 | fn read_hostname(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
234 | use pki_types::ServerName; |
235 | let raw = PayloadU16::read(r)?; |
236 | |
237 | match ServerName::try_from(raw.0.as_slice()) { |
238 | Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())), |
239 | Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)), |
240 | Ok(_) | Err(_) => { |
241 | warn!( |
242 | "Illegal SNI hostname received {:?}" , |
243 | String::from_utf8_lossy(&raw.0) |
244 | ); |
245 | Err(InvalidMessage::InvalidServerName) |
246 | } |
247 | } |
248 | } |
249 | |
250 | fn encode(&self, bytes: &mut Vec<u8>) { |
251 | match self { |
252 | Self::HostName(name) => { |
253 | (name.as_ref().len() as u16).encode(bytes); |
254 | bytes.extend_from_slice(name.as_ref().as_bytes()); |
255 | } |
256 | Self::IpAddress(r) => r.encode(bytes), |
257 | Self::Unknown(r) => r.encode(bytes), |
258 | } |
259 | } |
260 | } |
261 | |
262 | #[derive (Clone, Debug)] |
263 | pub struct ServerName { |
264 | pub(crate) typ: ServerNameType, |
265 | pub(crate) payload: ServerNamePayload, |
266 | } |
267 | |
268 | impl Codec<'_> for ServerName { |
269 | fn encode(&self, bytes: &mut Vec<u8>) { |
270 | self.typ.encode(bytes); |
271 | self.payload.encode(bytes); |
272 | } |
273 | |
274 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
275 | let typ: ServerNameType = ServerNameType::read(r)?; |
276 | |
277 | let payload: ServerNamePayload = match typ { |
278 | ServerNameType::HostName => ServerNamePayload::read_hostname(r)?, |
279 | _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()), |
280 | }; |
281 | |
282 | Ok(Self { typ, payload }) |
283 | } |
284 | } |
285 | |
286 | impl TlsListElement for ServerName { |
287 | const SIZE_LEN: ListLength = ListLength::U16; |
288 | } |
289 | |
290 | pub(crate) trait ConvertServerNameList { |
291 | fn has_duplicate_names_for_type(&self) -> bool; |
292 | fn single_hostname(&self) -> Option<DnsName<'_>>; |
293 | } |
294 | |
295 | impl ConvertServerNameList for [ServerName] { |
296 | /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type." |
297 | fn has_duplicate_names_for_type(&self) -> bool { |
298 | has_duplicates::<_, _, u8>(self.iter().map(|name: &ServerName| name.typ)) |
299 | } |
300 | |
301 | fn single_hostname(&self) -> Option<DnsName<'_>> { |
302 | fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> { |
303 | if let ServerNamePayload::HostName(dns: &DnsName<'static>) = &name.payload { |
304 | Some(dns.borrow()) |
305 | } else { |
306 | None |
307 | } |
308 | } |
309 | |
310 | self.iter() |
311 | .filter_map(only_dns_hostnames) |
312 | .next() |
313 | } |
314 | } |
315 | |
316 | wrapped_payload!(pub struct ProtocolName, PayloadU8,); |
317 | |
318 | impl TlsListElement for ProtocolName { |
319 | const SIZE_LEN: ListLength = ListLength::U16; |
320 | } |
321 | |
322 | pub(crate) trait ConvertProtocolNameList { |
323 | fn from_slices(names: &[&[u8]]) -> Self; |
324 | fn to_slices(&self) -> Vec<&[u8]>; |
325 | fn as_single_slice(&self) -> Option<&[u8]>; |
326 | } |
327 | |
328 | impl ConvertProtocolNameList for Vec<ProtocolName> { |
329 | fn from_slices(names: &[&[u8]]) -> Self { |
330 | let mut ret = Self::new(); |
331 | |
332 | for name in names { |
333 | ret.push(ProtocolName::from(name.to_vec())); |
334 | } |
335 | |
336 | ret |
337 | } |
338 | |
339 | fn to_slices(&self) -> Vec<&[u8]> { |
340 | self.iter() |
341 | .map(|proto| proto.as_ref()) |
342 | .collect::<Vec<&[u8]>>() |
343 | } |
344 | |
345 | fn as_single_slice(&self) -> Option<&[u8]> { |
346 | if self.len() == 1 { |
347 | Some(self[0].as_ref()) |
348 | } else { |
349 | None |
350 | } |
351 | } |
352 | } |
353 | |
354 | // --- TLS 1.3 Key shares --- |
355 | #[derive (Clone, Debug)] |
356 | pub struct KeyShareEntry { |
357 | pub(crate) group: NamedGroup, |
358 | pub(crate) payload: PayloadU16, |
359 | } |
360 | |
361 | impl KeyShareEntry { |
362 | pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self { |
363 | Self { |
364 | group, |
365 | payload: PayloadU16::new(bytes:payload.into()), |
366 | } |
367 | } |
368 | |
369 | pub fn group(&self) -> NamedGroup { |
370 | self.group |
371 | } |
372 | } |
373 | |
374 | impl Codec<'_> for KeyShareEntry { |
375 | fn encode(&self, bytes: &mut Vec<u8>) { |
376 | self.group.encode(bytes); |
377 | self.payload.encode(bytes); |
378 | } |
379 | |
380 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
381 | let group: NamedGroup = NamedGroup::read(r)?; |
382 | let payload: PayloadU16 = PayloadU16::read(r)?; |
383 | |
384 | Ok(Self { group, payload }) |
385 | } |
386 | } |
387 | |
388 | // --- TLS 1.3 PresharedKey offers --- |
389 | #[derive (Clone, Debug)] |
390 | pub(crate) struct PresharedKeyIdentity { |
391 | pub(crate) identity: PayloadU16, |
392 | pub(crate) obfuscated_ticket_age: u32, |
393 | } |
394 | |
395 | impl PresharedKeyIdentity { |
396 | pub(crate) fn new(id: Vec<u8>, age: u32) -> Self { |
397 | Self { |
398 | identity: PayloadU16::new(bytes:id), |
399 | obfuscated_ticket_age: age, |
400 | } |
401 | } |
402 | } |
403 | |
404 | impl Codec<'_> for PresharedKeyIdentity { |
405 | fn encode(&self, bytes: &mut Vec<u8>) { |
406 | self.identity.encode(bytes); |
407 | self.obfuscated_ticket_age.encode(bytes); |
408 | } |
409 | |
410 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
411 | Ok(Self { |
412 | identity: PayloadU16::read(r)?, |
413 | obfuscated_ticket_age: u32::read(r)?, |
414 | }) |
415 | } |
416 | } |
417 | |
418 | impl TlsListElement for PresharedKeyIdentity { |
419 | const SIZE_LEN: ListLength = ListLength::U16; |
420 | } |
421 | |
422 | wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,); |
423 | |
424 | impl TlsListElement for PresharedKeyBinder { |
425 | const SIZE_LEN: ListLength = ListLength::U16; |
426 | } |
427 | |
428 | #[derive (Clone, Debug)] |
429 | pub struct PresharedKeyOffer { |
430 | pub(crate) identities: Vec<PresharedKeyIdentity>, |
431 | pub(crate) binders: Vec<PresharedKeyBinder>, |
432 | } |
433 | |
434 | impl PresharedKeyOffer { |
435 | /// Make a new one with one entry. |
436 | pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self { |
437 | Self { |
438 | identities: vec![id], |
439 | binders: vec![PresharedKeyBinder::from(binder)], |
440 | } |
441 | } |
442 | } |
443 | |
444 | impl Codec<'_> for PresharedKeyOffer { |
445 | fn encode(&self, bytes: &mut Vec<u8>) { |
446 | self.identities.encode(bytes); |
447 | self.binders.encode(bytes); |
448 | } |
449 | |
450 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
451 | Ok(Self { |
452 | identities: Vec::read(r)?, |
453 | binders: Vec::read(r)?, |
454 | }) |
455 | } |
456 | } |
457 | |
458 | // --- RFC6066 certificate status request --- |
459 | wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,); |
460 | |
461 | impl TlsListElement for ResponderId { |
462 | const SIZE_LEN: ListLength = ListLength::U16; |
463 | } |
464 | |
465 | #[derive (Clone, Debug)] |
466 | pub struct OcspCertificateStatusRequest { |
467 | pub(crate) responder_ids: Vec<ResponderId>, |
468 | pub(crate) extensions: PayloadU16, |
469 | } |
470 | |
471 | impl Codec<'_> for OcspCertificateStatusRequest { |
472 | fn encode(&self, bytes: &mut Vec<u8>) { |
473 | CertificateStatusType::OCSP.encode(bytes); |
474 | self.responder_ids.encode(bytes); |
475 | self.extensions.encode(bytes); |
476 | } |
477 | |
478 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
479 | Ok(Self { |
480 | responder_ids: Vec::read(r)?, |
481 | extensions: PayloadU16::read(r)?, |
482 | }) |
483 | } |
484 | } |
485 | |
486 | #[derive (Clone, Debug)] |
487 | pub enum CertificateStatusRequest { |
488 | Ocsp(OcspCertificateStatusRequest), |
489 | Unknown((CertificateStatusType, Payload<'static>)), |
490 | } |
491 | |
492 | impl Codec<'_> for CertificateStatusRequest { |
493 | fn encode(&self, bytes: &mut Vec<u8>) { |
494 | match self { |
495 | Self::Ocsp(r) => r.encode(bytes), |
496 | Self::Unknown((typ, payload)) => { |
497 | typ.encode(bytes); |
498 | payload.encode(bytes); |
499 | } |
500 | } |
501 | } |
502 | |
503 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
504 | let typ = CertificateStatusType::read(r)?; |
505 | |
506 | match typ { |
507 | CertificateStatusType::OCSP => { |
508 | let ocsp_req = OcspCertificateStatusRequest::read(r)?; |
509 | Ok(Self::Ocsp(ocsp_req)) |
510 | } |
511 | _ => { |
512 | let data = Payload::read(r).into_owned(); |
513 | Ok(Self::Unknown((typ, data))) |
514 | } |
515 | } |
516 | } |
517 | } |
518 | |
519 | impl CertificateStatusRequest { |
520 | pub(crate) fn build_ocsp() -> Self { |
521 | let ocsp: OcspCertificateStatusRequest = OcspCertificateStatusRequest { |
522 | responder_ids: Vec::new(), |
523 | extensions: PayloadU16::empty(), |
524 | }; |
525 | Self::Ocsp(ocsp) |
526 | } |
527 | } |
528 | |
529 | // --- |
530 | |
531 | impl TlsListElement for PSKKeyExchangeMode { |
532 | const SIZE_LEN: ListLength = ListLength::U8; |
533 | } |
534 | |
535 | impl TlsListElement for KeyShareEntry { |
536 | const SIZE_LEN: ListLength = ListLength::U16; |
537 | } |
538 | |
539 | impl TlsListElement for ProtocolVersion { |
540 | const SIZE_LEN: ListLength = ListLength::U8; |
541 | } |
542 | |
543 | impl TlsListElement for CertificateType { |
544 | const SIZE_LEN: ListLength = ListLength::U8; |
545 | } |
546 | |
547 | impl TlsListElement for CertificateCompressionAlgorithm { |
548 | const SIZE_LEN: ListLength = ListLength::U8; |
549 | } |
550 | |
551 | #[derive (Clone, Debug)] |
552 | pub enum ClientExtension { |
553 | EcPointFormats(Vec<ECPointFormat>), |
554 | NamedGroups(Vec<NamedGroup>), |
555 | SignatureAlgorithms(Vec<SignatureScheme>), |
556 | ServerName(Vec<ServerName>), |
557 | SessionTicket(ClientSessionTicket), |
558 | Protocols(Vec<ProtocolName>), |
559 | SupportedVersions(Vec<ProtocolVersion>), |
560 | KeyShare(Vec<KeyShareEntry>), |
561 | PresharedKeyModes(Vec<PSKKeyExchangeMode>), |
562 | PresharedKey(PresharedKeyOffer), |
563 | Cookie(PayloadU16), |
564 | ExtendedMasterSecretRequest, |
565 | CertificateStatusRequest(CertificateStatusRequest), |
566 | ServerCertTypes(Vec<CertificateType>), |
567 | ClientCertTypes(Vec<CertificateType>), |
568 | TransportParameters(Vec<u8>), |
569 | TransportParametersDraft(Vec<u8>), |
570 | EarlyData, |
571 | CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>), |
572 | EncryptedClientHello(EncryptedClientHello), |
573 | EncryptedClientHelloOuterExtensions(Vec<ExtensionType>), |
574 | AuthorityNames(Vec<DistinguishedName>), |
575 | Unknown(UnknownExtension), |
576 | } |
577 | |
578 | impl ClientExtension { |
579 | pub(crate) fn ext_type(&self) -> ExtensionType { |
580 | match self { |
581 | Self::EcPointFormats(_) => ExtensionType::ECPointFormats, |
582 | Self::NamedGroups(_) => ExtensionType::EllipticCurves, |
583 | Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms, |
584 | Self::ServerName(_) => ExtensionType::ServerName, |
585 | Self::SessionTicket(_) => ExtensionType::SessionTicket, |
586 | Self::Protocols(_) => ExtensionType::ALProtocolNegotiation, |
587 | Self::SupportedVersions(_) => ExtensionType::SupportedVersions, |
588 | Self::KeyShare(_) => ExtensionType::KeyShare, |
589 | Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes, |
590 | Self::PresharedKey(_) => ExtensionType::PreSharedKey, |
591 | Self::Cookie(_) => ExtensionType::Cookie, |
592 | Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret, |
593 | Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest, |
594 | Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType, |
595 | Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType, |
596 | Self::TransportParameters(_) => ExtensionType::TransportParameters, |
597 | Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft, |
598 | Self::EarlyData => ExtensionType::EarlyData, |
599 | Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate, |
600 | Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello, |
601 | Self::EncryptedClientHelloOuterExtensions(_) => { |
602 | ExtensionType::EncryptedClientHelloOuterExtensions |
603 | } |
604 | Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities, |
605 | Self::Unknown(r) => r.typ, |
606 | } |
607 | } |
608 | } |
609 | |
impl Codec<'_> for ClientExtension {
    /// Encodes the extension as its `ExtensionType` followed by the body, written through
    /// a `ListLength::U16` `LengthPrefixedBuffer`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        // Body bytes are written to `nested.buf`; `LengthPrefixedBuffer` owns the `u16`
        // length prefix. NOTE(review): presumably finalised when `nested` is dropped —
        // confirm in `codec`.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            Self::NamedGroups(r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
            Self::ServerName(r) => r.encode(nested.buf),
            // These variants write nothing into the body.
            Self::SessionTicket(ClientSessionTicket::Request)
            | Self::ExtendedMasterSecretRequest
            | Self::EarlyData => {}
            Self::SessionTicket(ClientSessionTicket::Offer(r)) => r.encode(nested.buf),
            Self::Protocols(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKeyModes(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::Cookie(r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
            Self::ClientCertTypes(r) => r.encode(nested.buf),
            Self::ServerCertTypes(r) => r.encode(nested.buf),
            // Transport parameters are held already encoded; copy them verbatim.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
            Self::AuthorityNames(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    /// Decodes one extension: its `ExtensionType`, a `u16` body length, and a body
    /// parsed from a sub-reader bounded by that length.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the header, creating the sub-reader, or decoding
    /// the body; `expect_empty("ClientExtension")` reports body bytes left unconsumed.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
            ExtensionType::SessionTicket => {
                // Empty body => `Request`; non-empty body => `Offer` with those bytes.
                if sub.any_left() {
                    let contents = Payload::read(&mut sub).into_owned();
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
                } else {
                    Self::SessionTicket(ClientSessionTicket::Request)
                }
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            // Recognised only with an empty body; a non-empty body falls through to the
            // `Unknown` arm at the end.
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            // Transport parameters are kept as raw bytes.
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            // Like ExtendedMasterSecret: empty body only, otherwise `Unknown`.
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            ExtensionType::CompressCertificate => {
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
            }
            ExtensionType::EncryptedClientHelloOuterExtensions => {
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
            }
            ExtensionType::CertificateAuthorities => Self::AuthorityNames(Vec::read(&mut sub)?),
            // NOTE(review): `ExtensionType::EncryptedClientHello` has no arm above, so it
            // decodes as `Unknown` even though `encode` handles `Self::EncryptedClientHello`.
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
696 | |
697 | fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> { |
698 | let dns_name_str: &str = dns_name.as_ref(); |
699 | |
700 | // RFC6066: "The hostname is represented as a byte string using |
701 | // ASCII encoding without a trailing dot" |
702 | if dns_name_str.ends_with('.' ) { |
703 | let trimmed: &str = &dns_name_str[0..dns_name_str.len() - 1]; |
704 | DnsNameDnsName<'_>::try_from(trimmed) |
705 | .unwrap() |
706 | .to_owned() |
707 | } else { |
708 | dns_name.to_owned() |
709 | } |
710 | } |
711 | |
712 | impl ClientExtension { |
713 | /// Make a basic SNI ServerNameRequest quoting `hostname`. |
714 | pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self { |
715 | let name: ServerName = ServerName { |
716 | typ: ServerNameType::HostName, |
717 | payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)), |
718 | }; |
719 | |
720 | Self::ServerName(vec![name]) |
721 | } |
722 | } |
723 | |
724 | #[derive (Clone, Debug)] |
725 | pub enum ClientSessionTicket { |
726 | Request, |
727 | Offer(Payload<'static>), |
728 | } |
729 | |
730 | #[derive (Clone, Debug)] |
731 | pub enum ServerExtension { |
732 | EcPointFormats(Vec<ECPointFormat>), |
733 | ServerNameAck, |
734 | SessionTicketAck, |
735 | RenegotiationInfo(PayloadU8), |
736 | Protocols(Vec<ProtocolName>), |
737 | KeyShare(KeyShareEntry), |
738 | PresharedKey(u16), |
739 | ExtendedMasterSecretAck, |
740 | CertificateStatusAck, |
741 | ServerCertType(CertificateType), |
742 | ClientCertType(CertificateType), |
743 | SupportedVersions(ProtocolVersion), |
744 | TransportParameters(Vec<u8>), |
745 | TransportParametersDraft(Vec<u8>), |
746 | EarlyData, |
747 | EncryptedClientHello(ServerEncryptedClientHello), |
748 | Unknown(UnknownExtension), |
749 | } |
750 | |
751 | impl ServerExtension { |
752 | pub(crate) fn ext_type(&self) -> ExtensionType { |
753 | match self { |
754 | Self::EcPointFormats(_) => ExtensionType::ECPointFormats, |
755 | Self::ServerNameAck => ExtensionType::ServerName, |
756 | Self::SessionTicketAck => ExtensionType::SessionTicket, |
757 | Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo, |
758 | Self::Protocols(_) => ExtensionType::ALProtocolNegotiation, |
759 | Self::KeyShare(_) => ExtensionType::KeyShare, |
760 | Self::PresharedKey(_) => ExtensionType::PreSharedKey, |
761 | Self::ClientCertType(_) => ExtensionType::ClientCertificateType, |
762 | Self::ServerCertType(_) => ExtensionType::ServerCertificateType, |
763 | Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret, |
764 | Self::CertificateStatusAck => ExtensionType::StatusRequest, |
765 | Self::SupportedVersions(_) => ExtensionType::SupportedVersions, |
766 | Self::TransportParameters(_) => ExtensionType::TransportParameters, |
767 | Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft, |
768 | Self::EarlyData => ExtensionType::EarlyData, |
769 | Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello, |
770 | Self::Unknown(r: &UnknownExtension) => r.typ, |
771 | } |
772 | } |
773 | } |
774 | |
impl Codec<'_> for ServerExtension {
    /// Encodes the extension as its `ExtensionType` followed by the body, written through
    /// a `ListLength::U16` `LengthPrefixedBuffer`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        // Body bytes are written to `nested.buf`; `LengthPrefixedBuffer` owns the `u16`
        // length prefix. NOTE(review): presumably finalised when `nested` is dropped —
        // confirm in `codec`.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            // These variants write nothing into the body.
            Self::ServerNameAck
            | Self::SessionTicketAck
            | Self::ExtendedMasterSecretAck
            | Self::CertificateStatusAck
            | Self::EarlyData => {}
            Self::RenegotiationInfo(r) => r.encode(nested.buf),
            Self::Protocols(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::ClientCertType(r) => r.encode(nested.buf),
            Self::ServerCertType(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            // Transport parameters are held already encoded; copy them verbatim.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    /// Decodes one extension: its `ExtensionType`, a `u16` body length, and a body
    /// parsed from a sub-reader bounded by that length.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the header, creating the sub-reader, or decoding
    /// the body; `expect_empty("ServerExtension")` reports body bytes left unconsumed.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            // The acknowledgement arms (here and for ExtendedMasterSecret / EarlyData
            // below) read nothing, so any non-empty body is left in `sub` and rejected by
            // `expect_empty` after the match.
            ExtensionType::ServerName => Self::ServerNameAck,
            ExtensionType::SessionTicket => Self::SessionTicketAck,
            ExtensionType::StatusRequest => Self::CertificateStatusAck,
            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::ClientCertificateType => {
                Self::ClientCertType(CertificateType::read(&mut sub)?)
            }
            ExtensionType::ServerCertificateType => {
                Self::ServerCertType(CertificateType::read(&mut sub)?)
            }
            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
            }
            // Transport parameters are kept as raw bytes.
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            ExtensionType::EarlyData => Self::EarlyData,
            ExtensionType::EncryptedClientHello => {
                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
            }
            // Everything else is kept verbatim.
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ServerExtension")
            .map(|_| ext)
    }
}
841 | |
842 | impl ServerExtension { |
843 | pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self { |
844 | Self::Protocols(Vec::from_slices(names:proto)) |
845 | } |
846 | |
847 | #[cfg (feature = "tls12" )] |
848 | pub(crate) fn make_empty_renegotiation_info() -> Self { |
849 | let empty: Vec = Vec::new(); |
850 | Self::RenegotiationInfo(PayloadU8::new(bytes:empty)) |
851 | } |
852 | } |
853 | |
854 | #[derive (Clone, Debug)] |
855 | pub struct ClientHelloPayload { |
856 | pub client_version: ProtocolVersion, |
857 | pub random: Random, |
858 | pub session_id: SessionId, |
859 | pub cipher_suites: Vec<CipherSuite>, |
860 | pub compression_methods: Vec<Compression>, |
861 | pub extensions: Vec<ClientExtension>, |
862 | } |
863 | |
864 | impl Codec<'_> for ClientHelloPayload { |
865 | fn encode(&self, bytes: &mut Vec<u8>) { |
866 | self.payload_encode(bytes, Encoding::Standard) |
867 | } |
868 | |
869 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
870 | let mut ret = Self { |
871 | client_version: ProtocolVersion::read(r)?, |
872 | random: Random::read(r)?, |
873 | session_id: SessionId::read(r)?, |
874 | cipher_suites: Vec::read(r)?, |
875 | compression_methods: Vec::read(r)?, |
876 | extensions: Vec::new(), |
877 | }; |
878 | |
879 | if r.any_left() { |
880 | ret.extensions = Vec::read(r)?; |
881 | } |
882 | |
883 | match (r.any_left(), ret.extensions.is_empty()) { |
884 | (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload" )), |
885 | (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload" )), |
886 | _ => Ok(ret), |
887 | } |
888 | } |
889 | } |
890 | |
891 | impl TlsListElement for CipherSuite { |
892 | const SIZE_LEN: ListLength = ListLength::U16; |
893 | } |
894 | |
895 | impl TlsListElement for Compression { |
896 | const SIZE_LEN: ListLength = ListLength::U8; |
897 | } |
898 | |
899 | impl TlsListElement for ClientExtension { |
900 | const SIZE_LEN: ListLength = ListLength::U16; |
901 | } |
902 | |
903 | impl TlsListElement for ExtensionType { |
904 | const SIZE_LEN: ListLength = ListLength::U8; |
905 | } |
906 | |
907 | impl ClientHelloPayload { |
908 | pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> { |
909 | let mut bytes = Vec::new(); |
910 | self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress }); |
911 | bytes |
912 | } |
913 | |
914 | pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) { |
915 | self.client_version.encode(bytes); |
916 | self.random.encode(bytes); |
917 | |
918 | match purpose { |
919 | // SessionID is required to be empty in the encoded inner client hello. |
920 | Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes), |
921 | _ => self.session_id.encode(bytes), |
922 | } |
923 | |
924 | self.cipher_suites.encode(bytes); |
925 | self.compression_methods.encode(bytes); |
926 | |
927 | let to_compress = match purpose { |
928 | // Compressed extensions must be replaced in the encoded inner client hello. |
929 | Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress, |
930 | _ => { |
931 | if !self.extensions.is_empty() { |
932 | self.extensions.encode(bytes); |
933 | } |
934 | return; |
935 | } |
936 | }; |
937 | |
938 | // Safety: not empty check in match guard. |
939 | let first_compressed_type = *to_compress.first().unwrap(); |
940 | |
941 | // Compressed extensions are in a contiguous range and must be replaced |
942 | // with a marker extension. |
943 | let compressed_start_idx = self |
944 | .extensions |
945 | .iter() |
946 | .position(|ext| ext.ext_type() == first_compressed_type); |
947 | let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len()); |
948 | let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress); |
949 | |
950 | let exts = self |
951 | .extensions |
952 | .iter() |
953 | .enumerate() |
954 | .filter_map(|(i, ext)| { |
955 | if Some(i) == compressed_start_idx { |
956 | Some(&marker_ext) |
957 | } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx { |
958 | None |
959 | } else { |
960 | Some(ext) |
961 | } |
962 | }); |
963 | |
964 | let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
965 | for ext in exts { |
966 | ext.encode(nested.buf); |
967 | } |
968 | } |
969 | |
970 | /// Returns true if there is more than one extension of a given |
971 | /// type. |
972 | pub(crate) fn has_duplicate_extension(&self) -> bool { |
973 | has_duplicates::<_, _, u16>( |
974 | self.extensions |
975 | .iter() |
976 | .map(|ext| ext.ext_type()), |
977 | ) |
978 | } |
979 | |
980 | pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> { |
981 | self.extensions |
982 | .iter() |
983 | .find(|x| x.ext_type() == ext) |
984 | } |
985 | |
986 | pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> { |
987 | let ext = self.find_extension(ExtensionType::ServerName)?; |
988 | match ext { |
989 | // Does this comply with RFC6066? |
990 | // |
991 | // [RFC6066][] specifies that literal IP addresses are illegal in |
992 | // `ServerName`s with a `name_type` of `host_name`. |
993 | // |
994 | // Some clients incorrectly send such extensions: we choose to |
995 | // successfully parse these (into `ServerNamePayload::IpAddress`) |
996 | // but then act like the client sent no `server_name` extension. |
997 | // |
998 | // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3 |
999 | ClientExtension::ServerName(req) |
1000 | if !req |
1001 | .iter() |
1002 | .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) => |
1003 | { |
1004 | Some(req) |
1005 | } |
1006 | _ => None, |
1007 | } |
1008 | } |
1009 | |
1010 | pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> { |
1011 | let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?; |
1012 | match ext { |
1013 | ClientExtension::SignatureAlgorithms(req) => Some(req), |
1014 | _ => None, |
1015 | } |
1016 | } |
1017 | |
1018 | pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> { |
1019 | let ext = self.find_extension(ExtensionType::EllipticCurves)?; |
1020 | match ext { |
1021 | ClientExtension::NamedGroups(req) => Some(req), |
1022 | _ => None, |
1023 | } |
1024 | } |
1025 | |
1026 | #[cfg (feature = "tls12" )] |
1027 | pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> { |
1028 | let ext = self.find_extension(ExtensionType::ECPointFormats)?; |
1029 | match ext { |
1030 | ClientExtension::EcPointFormats(req) => Some(req), |
1031 | _ => None, |
1032 | } |
1033 | } |
1034 | |
1035 | pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> { |
1036 | let ext = self.find_extension(ExtensionType::ServerCertificateType)?; |
1037 | match ext { |
1038 | ClientExtension::ServerCertTypes(req) => Some(req), |
1039 | _ => None, |
1040 | } |
1041 | } |
1042 | |
1043 | pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> { |
1044 | let ext = self.find_extension(ExtensionType::ClientCertificateType)?; |
1045 | match ext { |
1046 | ClientExtension::ClientCertTypes(req) => Some(req), |
1047 | _ => None, |
1048 | } |
1049 | } |
1050 | |
1051 | pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> { |
1052 | let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?; |
1053 | match ext { |
1054 | ClientExtension::Protocols(req) => Some(req), |
1055 | _ => None, |
1056 | } |
1057 | } |
1058 | |
1059 | pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> { |
1060 | let ext = self |
1061 | .find_extension(ExtensionType::TransportParameters) |
1062 | .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?; |
1063 | match ext { |
1064 | ClientExtension::TransportParameters(bytes) |
1065 | | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()), |
1066 | _ => None, |
1067 | } |
1068 | } |
1069 | |
1070 | #[cfg (feature = "tls12" )] |
1071 | pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> { |
1072 | self.find_extension(ExtensionType::SessionTicket) |
1073 | } |
1074 | |
1075 | pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> { |
1076 | let ext = self.find_extension(ExtensionType::SupportedVersions)?; |
1077 | match ext { |
1078 | ClientExtension::SupportedVersions(vers) => Some(vers), |
1079 | _ => None, |
1080 | } |
1081 | } |
1082 | |
1083 | pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> { |
1084 | let ext = self.find_extension(ExtensionType::KeyShare)?; |
1085 | match ext { |
1086 | ClientExtension::KeyShare(shares) => Some(shares), |
1087 | _ => None, |
1088 | } |
1089 | } |
1090 | |
1091 | pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool { |
1092 | self.keyshare_extension() |
1093 | .map(|entries| { |
1094 | has_duplicates::<_, _, u16>( |
1095 | entries |
1096 | .iter() |
1097 | .map(|kse| u16::from(kse.group)), |
1098 | ) |
1099 | }) |
1100 | .unwrap_or_default() |
1101 | } |
1102 | |
1103 | pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> { |
1104 | let ext = self.find_extension(ExtensionType::PreSharedKey)?; |
1105 | match ext { |
1106 | ClientExtension::PresharedKey(psk) => Some(psk), |
1107 | _ => None, |
1108 | } |
1109 | } |
1110 | |
1111 | pub(crate) fn check_psk_ext_is_last(&self) -> bool { |
1112 | self.extensions |
1113 | .last() |
1114 | .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey) |
1115 | } |
1116 | |
1117 | pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> { |
1118 | let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?; |
1119 | match ext { |
1120 | ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes), |
1121 | _ => None, |
1122 | } |
1123 | } |
1124 | |
1125 | pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool { |
1126 | self.psk_modes() |
1127 | .map(|modes| modes.contains(&mode)) |
1128 | .unwrap_or(false) |
1129 | } |
1130 | |
1131 | pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) { |
1132 | let last_extension = self.extensions.last_mut(); |
1133 | if let Some(ClientExtension::PresharedKey(offer)) = last_extension { |
1134 | offer.binders[0] = PresharedKeyBinder::from(binder.into()); |
1135 | } |
1136 | } |
1137 | |
1138 | #[cfg (feature = "tls12" )] |
1139 | pub(crate) fn ems_support_offered(&self) -> bool { |
1140 | self.find_extension(ExtensionType::ExtendedMasterSecret) |
1141 | .is_some() |
1142 | } |
1143 | |
1144 | pub(crate) fn early_data_extension_offered(&self) -> bool { |
1145 | self.find_extension(ExtensionType::EarlyData) |
1146 | .is_some() |
1147 | } |
1148 | |
1149 | pub(crate) fn certificate_compression_extension( |
1150 | &self, |
1151 | ) -> Option<&[CertificateCompressionAlgorithm]> { |
1152 | let ext = self.find_extension(ExtensionType::CompressCertificate)?; |
1153 | match ext { |
1154 | ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs), |
1155 | _ => None, |
1156 | } |
1157 | } |
1158 | |
1159 | pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool { |
1160 | if let Some(algs) = self.certificate_compression_extension() { |
1161 | has_duplicates::<_, _, u16>(algs.iter().cloned()) |
1162 | } else { |
1163 | false |
1164 | } |
1165 | } |
1166 | |
1167 | pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> { |
1168 | match self.find_extension(ExtensionType::CertificateAuthorities)? { |
1169 | ClientExtension::AuthorityNames(ext) => Some(ext), |
1170 | _ => unreachable!("extension type checked" ), |
1171 | } |
1172 | } |
1173 | } |
1174 | |
/// An extension carried in a HelloRetryRequest.
#[derive (Clone, Debug)]
pub(crate) enum HelloRetryExtension {
    /// `key_share`: the group the server asks the client to retry with.
    KeyShare(NamedGroup),
    /// `cookie`: opaque data, carried as a two-byte-length-prefixed payload.
    Cookie(PayloadU16),
    /// `supported_versions`: a single protocol version.
    SupportedVersions(ProtocolVersion),
    /// `encrypted_client_hello`: the raw extension body, kept unparsed.
    /// For ECH confirmation it is encoded as 8 zero bytes instead.
    EchHelloRetryRequest(Vec<u8>),
    /// Any other extension type, kept as-is.
    Unknown(UnknownExtension),
}
1183 | |
1184 | impl HelloRetryExtension { |
1185 | pub(crate) fn ext_type(&self) -> ExtensionType { |
1186 | match self { |
1187 | Self::KeyShare(_) => ExtensionType::KeyShare, |
1188 | Self::Cookie(_) => ExtensionType::Cookie, |
1189 | Self::SupportedVersions(_) => ExtensionType::SupportedVersions, |
1190 | Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello, |
1191 | Self::Unknown(r: &UnknownExtension) => r.typ, |
1192 | } |
1193 | } |
1194 | } |
1195 | |
1196 | impl Codec<'_> for HelloRetryExtension { |
1197 | fn encode(&self, bytes: &mut Vec<u8>) { |
1198 | self.ext_type().encode(bytes); |
1199 | |
1200 | let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
1201 | match self { |
1202 | Self::KeyShare(r) => r.encode(nested.buf), |
1203 | Self::Cookie(r) => r.encode(nested.buf), |
1204 | Self::SupportedVersions(r) => r.encode(nested.buf), |
1205 | Self::EchHelloRetryRequest(r) => { |
1206 | nested.buf.extend_from_slice(r); |
1207 | } |
1208 | Self::Unknown(r) => r.encode(nested.buf), |
1209 | } |
1210 | } |
1211 | |
1212 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1213 | let typ = ExtensionType::read(r)?; |
1214 | let len = u16::read(r)? as usize; |
1215 | let mut sub = r.sub(len)?; |
1216 | |
1217 | let ext = match typ { |
1218 | ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?), |
1219 | ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?), |
1220 | ExtensionType::SupportedVersions => { |
1221 | Self::SupportedVersions(ProtocolVersion::read(&mut sub)?) |
1222 | } |
1223 | ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()), |
1224 | _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)), |
1225 | }; |
1226 | |
1227 | sub.expect_empty("HelloRetryExtension" ) |
1228 | .map(|_| ext) |
1229 | } |
1230 | } |
1231 | |
1232 | impl TlsListElement for HelloRetryExtension { |
1233 | const SIZE_LEN: ListLength = ListLength::U16; |
1234 | } |
1235 | |
/// A HelloRetryRequest message.
///
/// On the wire this is shaped like a ServerHello. The random is always
/// `HELLO_RETRY_REQUEST_RANDOM` and the compression method is always null, so
/// neither is stored here.
#[derive (Clone, Debug)]
pub struct HelloRetryRequest {
    /// Version field as sent on the wire. `Codec::read` sets a placeholder
    /// (`ProtocolVersion::Unknown(0)`); presumably the caller fills in the real value.
    pub(crate) legacy_version: ProtocolVersion,
    /// Session ID echoed back to the client.
    pub session_id: SessionId,
    /// The cipher suite selected by the server.
    pub(crate) cipher_suite: CipherSuite,
    /// Extensions, in wire order.
    pub(crate) extensions: Vec<HelloRetryExtension>,
}
1243 | |
1244 | impl Codec<'_> for HelloRetryRequest { |
1245 | fn encode(&self, bytes: &mut Vec<u8>) { |
1246 | self.payload_encode(bytes, purpose:Encoding::Standard) |
1247 | } |
1248 | |
1249 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1250 | let session_id: SessionId = SessionId::read(r)?; |
1251 | let cipher_suite: CipherSuite = CipherSuite::read(r)?; |
1252 | let compression: Compression = Compression::read(r)?; |
1253 | |
1254 | if compression != Compression::Null { |
1255 | return Err(InvalidMessage::UnsupportedCompression); |
1256 | } |
1257 | |
1258 | Ok(Self { |
1259 | legacy_version: ProtocolVersion::Unknown(0), |
1260 | session_id, |
1261 | cipher_suite, |
1262 | extensions: Vec::read(r)?, |
1263 | }) |
1264 | } |
1265 | } |
1266 | |
1267 | impl HelloRetryRequest { |
1268 | /// Returns true if there is more than one extension of a given |
1269 | /// type. |
1270 | pub(crate) fn has_duplicate_extension(&self) -> bool { |
1271 | has_duplicates::<_, _, u16>( |
1272 | self.extensions |
1273 | .iter() |
1274 | .map(|ext| ext.ext_type()), |
1275 | ) |
1276 | } |
1277 | |
1278 | pub(crate) fn has_unknown_extension(&self) -> bool { |
1279 | self.extensions.iter().any(|ext| { |
1280 | ext.ext_type() != ExtensionType::KeyShare |
1281 | && ext.ext_type() != ExtensionType::SupportedVersions |
1282 | && ext.ext_type() != ExtensionType::Cookie |
1283 | && ext.ext_type() != ExtensionType::EncryptedClientHello |
1284 | }) |
1285 | } |
1286 | |
1287 | fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> { |
1288 | self.extensions |
1289 | .iter() |
1290 | .find(|x| x.ext_type() == ext) |
1291 | } |
1292 | |
1293 | pub fn requested_key_share_group(&self) -> Option<NamedGroup> { |
1294 | let ext = self.find_extension(ExtensionType::KeyShare)?; |
1295 | match ext { |
1296 | HelloRetryExtension::KeyShare(grp) => Some(*grp), |
1297 | _ => None, |
1298 | } |
1299 | } |
1300 | |
1301 | pub(crate) fn cookie(&self) -> Option<&PayloadU16> { |
1302 | let ext = self.find_extension(ExtensionType::Cookie)?; |
1303 | match ext { |
1304 | HelloRetryExtension::Cookie(ck) => Some(ck), |
1305 | _ => None, |
1306 | } |
1307 | } |
1308 | |
1309 | pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> { |
1310 | let ext = self.find_extension(ExtensionType::SupportedVersions)?; |
1311 | match ext { |
1312 | HelloRetryExtension::SupportedVersions(ver) => Some(*ver), |
1313 | _ => None, |
1314 | } |
1315 | } |
1316 | |
1317 | pub(crate) fn ech(&self) -> Option<&Vec<u8>> { |
1318 | let ext = self.find_extension(ExtensionType::EncryptedClientHello)?; |
1319 | match ext { |
1320 | HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech), |
1321 | _ => None, |
1322 | } |
1323 | } |
1324 | |
1325 | fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) { |
1326 | self.legacy_version.encode(bytes); |
1327 | HELLO_RETRY_REQUEST_RANDOM.encode(bytes); |
1328 | self.session_id.encode(bytes); |
1329 | self.cipher_suite.encode(bytes); |
1330 | Compression::Null.encode(bytes); |
1331 | |
1332 | match purpose { |
1333 | // For the purpose of ECH confirmation, the Encrypted Client Hello extension |
1334 | // must have its payload replaced by 8 zero bytes. |
1335 | // |
1336 | // See draft-ietf-tls-esni-18 7.2.1: |
1337 | // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2> |
1338 | Encoding::EchConfirmation => { |
1339 | let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
1340 | for ext in &self.extensions { |
1341 | match ext.ext_type() { |
1342 | ExtensionType::EncryptedClientHello => { |
1343 | HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8]) |
1344 | .encode(extensions.buf); |
1345 | } |
1346 | _ => { |
1347 | ext.encode(extensions.buf); |
1348 | } |
1349 | } |
1350 | } |
1351 | } |
1352 | _ => { |
1353 | self.extensions.encode(bytes); |
1354 | } |
1355 | } |
1356 | } |
1357 | } |
1358 | |
/// The body of a `ServerHello` handshake message.
#[derive (Clone, Debug)]
pub struct ServerHelloPayload {
    /// Extensions, in wire order. Omitted from the encoding entirely when empty.
    pub extensions: Vec<ServerExtension>,
    /// Version field as sent on the wire. `Codec::read` sets a placeholder
    /// (`ProtocolVersion::Unknown(0)`); presumably the caller fills it in.
    pub(crate) legacy_version: ProtocolVersion,
    /// Server random. `Codec::read` sets `ZERO_RANDOM` as a placeholder;
    /// presumably the caller fills it in.
    pub(crate) random: Random,
    /// Session ID echoed to the client.
    pub(crate) session_id: SessionId,
    /// The selected cipher suite.
    pub(crate) cipher_suite: CipherSuite,
    /// The selected compression method, as received.
    pub(crate) compression_method: Compression,
}
1368 | |
1369 | impl Codec<'_> for ServerHelloPayload { |
1370 | fn encode(&self, bytes: &mut Vec<u8>) { |
1371 | self.payload_encode(bytes, Encoding::Standard) |
1372 | } |
1373 | |
1374 | // minus version and random, which have already been read. |
1375 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1376 | let session_id = SessionId::read(r)?; |
1377 | let suite = CipherSuite::read(r)?; |
1378 | let compression = Compression::read(r)?; |
1379 | |
1380 | // RFC5246: |
1381 | // "The presence of extensions can be detected by determining whether |
1382 | // there are bytes following the compression_method field at the end of |
1383 | // the ServerHello." |
1384 | let extensions = if r.any_left() { Vec::read(r)? } else { vec![] }; |
1385 | |
1386 | let ret = Self { |
1387 | legacy_version: ProtocolVersion::Unknown(0), |
1388 | random: ZERO_RANDOM, |
1389 | session_id, |
1390 | cipher_suite: suite, |
1391 | compression_method: compression, |
1392 | extensions, |
1393 | }; |
1394 | |
1395 | r.expect_empty("ServerHelloPayload" ) |
1396 | .map(|_| ret) |
1397 | } |
1398 | } |
1399 | |
1400 | impl HasServerExtensions for ServerHelloPayload { |
1401 | fn extensions(&self) -> &[ServerExtension] { |
1402 | &self.extensions |
1403 | } |
1404 | } |
1405 | |
1406 | impl ServerHelloPayload { |
1407 | pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> { |
1408 | let ext = self.find_extension(ExtensionType::KeyShare)?; |
1409 | match ext { |
1410 | ServerExtension::KeyShare(share) => Some(share), |
1411 | _ => None, |
1412 | } |
1413 | } |
1414 | |
1415 | pub(crate) fn psk_index(&self) -> Option<u16> { |
1416 | let ext = self.find_extension(ExtensionType::PreSharedKey)?; |
1417 | match ext { |
1418 | ServerExtension::PresharedKey(index) => Some(*index), |
1419 | _ => None, |
1420 | } |
1421 | } |
1422 | |
1423 | pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> { |
1424 | let ext = self.find_extension(ExtensionType::ECPointFormats)?; |
1425 | match ext { |
1426 | ServerExtension::EcPointFormats(fmts) => Some(fmts), |
1427 | _ => None, |
1428 | } |
1429 | } |
1430 | |
1431 | #[cfg (feature = "tls12" )] |
1432 | pub(crate) fn ems_support_acked(&self) -> bool { |
1433 | self.find_extension(ExtensionType::ExtendedMasterSecret) |
1434 | .is_some() |
1435 | } |
1436 | |
1437 | pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> { |
1438 | let ext = self.find_extension(ExtensionType::SupportedVersions)?; |
1439 | match ext { |
1440 | ServerExtension::SupportedVersions(vers) => Some(*vers), |
1441 | _ => None, |
1442 | } |
1443 | } |
1444 | |
1445 | fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) { |
1446 | self.legacy_version.encode(bytes); |
1447 | |
1448 | match encoding { |
1449 | // When encoding a ServerHello for ECH confirmation, the random value |
1450 | // has the last 8 bytes zeroed out. |
1451 | Encoding::EchConfirmation => { |
1452 | // Indexing safety: self.random is 32 bytes long by definition. |
1453 | let rand_vec = self.random.get_encoding(); |
1454 | bytes.extend_from_slice(&rand_vec.as_slice()[..24]); |
1455 | bytes.extend_from_slice(&[0u8; 8]); |
1456 | } |
1457 | _ => self.random.encode(bytes), |
1458 | } |
1459 | |
1460 | self.session_id.encode(bytes); |
1461 | self.cipher_suite.encode(bytes); |
1462 | self.compression_method.encode(bytes); |
1463 | |
1464 | if !self.extensions.is_empty() { |
1465 | self.extensions.encode(bytes); |
1466 | } |
1467 | } |
1468 | } |
1469 | |
/// A sequence of DER-encoded certificates.
///
/// NOTE(review): presumably ordered end-entity first (`into_certificate_chain`
/// builds one from TLS 1.3 entries in order) — confirm at use sites.
#[derive (Clone, Default, Debug)]
pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1472 | |
1473 | impl CertificateChain<'_> { |
1474 | pub(crate) fn into_owned(self) -> CertificateChain<'static> { |
1475 | CertificateChain( |
1476 | self.0 |
1477 | .into_iter() |
1478 | .map(|c: CertificateDer<'_>| c.into_owned()) |
1479 | .collect(), |
1480 | ) |
1481 | } |
1482 | } |
1483 | |
1484 | impl<'a> Codec<'a> for CertificateChain<'a> { |
1485 | fn encode(&self, bytes: &mut Vec<u8>) { |
1486 | Vec::encode(&self.0, bytes) |
1487 | } |
1488 | |
1489 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
1490 | Vec::read(r).map(Self) |
1491 | } |
1492 | } |
1493 | |
1494 | impl<'a> Deref for CertificateChain<'a> { |
1495 | type Target = [CertificateDer<'a>]; |
1496 | |
1497 | fn deref(&self) -> &[CertificateDer<'a>] { |
1498 | &self.0 |
1499 | } |
1500 | } |
1501 | |
1502 | impl TlsListElement for CertificateDer<'_> { |
1503 | const SIZE_LEN: ListLength = ListLength::U24 { |
1504 | max: CERTIFICATE_MAX_SIZE_LIMIT, |
1505 | error: InvalidMessage::CertificatePayloadTooLarge, |
1506 | }; |
1507 | } |
1508 | |
1509 | /// TLS has a 16MB size limit on any handshake message, |
1510 | /// plus a 16MB limit on any given certificate. |
1511 | /// |
1512 | /// We contract that to 64KB to limit the amount of memory allocation |
1513 | /// that is directly controllable by the peer. |
1514 | pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000; |
1515 | |
/// An extension attached to a single TLS 1.3 certificate entry.
#[derive (Debug)]
pub(crate) enum CertificateExtension<'a> {
    /// `status_request`: carries an OCSP response for this certificate.
    CertificateStatus(CertificateStatus<'a>),
    /// Any other extension type, kept as-is.
    Unknown(UnknownExtension),
}
1521 | |
1522 | impl CertificateExtension<'_> { |
1523 | pub(crate) fn ext_type(&self) -> ExtensionType { |
1524 | match self { |
1525 | Self::CertificateStatus(_) => ExtensionType::StatusRequest, |
1526 | Self::Unknown(r: &UnknownExtension) => r.typ, |
1527 | } |
1528 | } |
1529 | |
1530 | pub(crate) fn cert_status(&self) -> Option<&[u8]> { |
1531 | match self { |
1532 | Self::CertificateStatus(cs: &CertificateStatus<'_>) => Some(cs.ocsp_response.0.bytes()), |
1533 | _ => None, |
1534 | } |
1535 | } |
1536 | |
1537 | pub(crate) fn into_owned(self) -> CertificateExtension<'static> { |
1538 | match self { |
1539 | Self::CertificateStatus(st: CertificateStatus<'_>) => CertificateExtension::CertificateStatus(st.into_owned()), |
1540 | Self::Unknown(unk: UnknownExtension) => CertificateExtension::Unknown(unk), |
1541 | } |
1542 | } |
1543 | } |
1544 | |
1545 | impl<'a> Codec<'a> for CertificateExtension<'a> { |
1546 | fn encode(&self, bytes: &mut Vec<u8>) { |
1547 | self.ext_type().encode(bytes); |
1548 | |
1549 | let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
1550 | match self { |
1551 | Self::CertificateStatus(r) => r.encode(nested.buf), |
1552 | Self::Unknown(r) => r.encode(nested.buf), |
1553 | } |
1554 | } |
1555 | |
1556 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
1557 | let typ = ExtensionType::read(r)?; |
1558 | let len = u16::read(r)? as usize; |
1559 | let mut sub = r.sub(len)?; |
1560 | |
1561 | let ext = match typ { |
1562 | ExtensionType::StatusRequest => { |
1563 | let st = CertificateStatus::read(&mut sub)?; |
1564 | Self::CertificateStatus(st) |
1565 | } |
1566 | _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)), |
1567 | }; |
1568 | |
1569 | sub.expect_empty("CertificateExtension" ) |
1570 | .map(|_| ext) |
1571 | } |
1572 | } |
1573 | |
1574 | impl TlsListElement for CertificateExtension<'_> { |
1575 | const SIZE_LEN: ListLength = ListLength::U16; |
1576 | } |
1577 | |
/// One entry of a TLS 1.3 certificate list: a certificate plus its own extensions.
#[derive (Debug)]
pub(crate) struct CertificateEntry<'a> {
    /// The DER-encoded certificate.
    pub(crate) cert: CertificateDer<'a>,
    /// Extensions that apply to this certificate only.
    pub(crate) exts: Vec<CertificateExtension<'a>>,
}
1583 | |
1584 | impl<'a> Codec<'a> for CertificateEntry<'a> { |
1585 | fn encode(&self, bytes: &mut Vec<u8>) { |
1586 | self.cert.encode(bytes); |
1587 | self.exts.encode(bytes); |
1588 | } |
1589 | |
1590 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
1591 | Ok(Self { |
1592 | cert: CertificateDer::read(r)?, |
1593 | exts: Vec::read(r)?, |
1594 | }) |
1595 | } |
1596 | } |
1597 | |
1598 | impl<'a> CertificateEntry<'a> { |
1599 | pub(crate) fn new(cert: CertificateDer<'a>) -> Self { |
1600 | Self { |
1601 | cert, |
1602 | exts: Vec::new(), |
1603 | } |
1604 | } |
1605 | |
1606 | pub(crate) fn into_owned(self) -> CertificateEntry<'static> { |
1607 | CertificateEntry { |
1608 | cert: self.cert.into_owned(), |
1609 | exts: self |
1610 | .exts |
1611 | .into_iter() |
1612 | .map(CertificateExtension::into_owned) |
1613 | .collect(), |
1614 | } |
1615 | } |
1616 | |
1617 | pub(crate) fn has_duplicate_extension(&self) -> bool { |
1618 | has_duplicates::<_, _, u16>( |
1619 | self.exts |
1620 | .iter() |
1621 | .map(|ext| ext.ext_type()), |
1622 | ) |
1623 | } |
1624 | |
1625 | pub(crate) fn has_unknown_extension(&self) -> bool { |
1626 | self.exts |
1627 | .iter() |
1628 | .any(|ext| ext.ext_type() != ExtensionType::StatusRequest) |
1629 | } |
1630 | |
1631 | pub(crate) fn ocsp_response(&self) -> Option<&[u8]> { |
1632 | self.exts |
1633 | .iter() |
1634 | .find(|ext| ext.ext_type() == ExtensionType::StatusRequest) |
1635 | .and_then(CertificateExtension::cert_status) |
1636 | } |
1637 | } |
1638 | |
1639 | impl TlsListElement for CertificateEntry<'_> { |
1640 | const SIZE_LEN: ListLength = ListLength::U24 { |
1641 | max: CERTIFICATE_MAX_SIZE_LIMIT, |
1642 | error: InvalidMessage::CertificatePayloadTooLarge, |
1643 | }; |
1644 | } |
1645 | |
/// The body of a TLS 1.3 `Certificate` message.
#[derive (Debug)]
pub struct CertificatePayloadTls13<'a> {
    /// Presumably the TLS 1.3 `certificate_request_context`; `new` leaves it empty.
    pub(crate) context: PayloadU8,
    /// Certificate entries. The first entry is treated as the end-entity
    /// certificate (see `end_entity_ocsp`).
    pub(crate) entries: Vec<CertificateEntry<'a>>,
}
1651 | |
1652 | impl<'a> Codec<'a> for CertificatePayloadTls13<'a> { |
1653 | fn encode(&self, bytes: &mut Vec<u8>) { |
1654 | self.context.encode(bytes); |
1655 | self.entries.encode(bytes); |
1656 | } |
1657 | |
1658 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
1659 | Ok(Self { |
1660 | context: PayloadU8::read(r)?, |
1661 | entries: Vec::read(r)?, |
1662 | }) |
1663 | } |
1664 | } |
1665 | |
1666 | impl<'a> CertificatePayloadTls13<'a> { |
1667 | pub(crate) fn new( |
1668 | certs: impl Iterator<Item = &'a CertificateDer<'a>>, |
1669 | ocsp_response: Option<&'a [u8]>, |
1670 | ) -> Self { |
1671 | Self { |
1672 | context: PayloadU8::empty(), |
1673 | entries: certs |
1674 | // zip certificate iterator with `ocsp_response` followed by |
1675 | // an infinite-length iterator of `None`. |
1676 | .zip( |
1677 | ocsp_response |
1678 | .into_iter() |
1679 | .map(Some) |
1680 | .chain(iter::repeat(None)), |
1681 | ) |
1682 | .map(|(cert, ocsp)| { |
1683 | let mut e = CertificateEntry::new(cert.clone()); |
1684 | if let Some(ocsp) = ocsp { |
1685 | e.exts |
1686 | .push(CertificateExtension::CertificateStatus( |
1687 | CertificateStatus::new(ocsp), |
1688 | )); |
1689 | } |
1690 | e |
1691 | }) |
1692 | .collect(), |
1693 | } |
1694 | } |
1695 | |
1696 | pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> { |
1697 | CertificatePayloadTls13 { |
1698 | context: self.context, |
1699 | entries: self |
1700 | .entries |
1701 | .into_iter() |
1702 | .map(CertificateEntry::into_owned) |
1703 | .collect(), |
1704 | } |
1705 | } |
1706 | |
1707 | pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool { |
1708 | for entry in &self.entries { |
1709 | if entry.has_duplicate_extension() { |
1710 | return true; |
1711 | } |
1712 | } |
1713 | |
1714 | false |
1715 | } |
1716 | |
1717 | pub(crate) fn any_entry_has_unknown_extension(&self) -> bool { |
1718 | for entry in &self.entries { |
1719 | if entry.has_unknown_extension() { |
1720 | return true; |
1721 | } |
1722 | } |
1723 | |
1724 | false |
1725 | } |
1726 | |
1727 | pub(crate) fn any_entry_has_extension(&self) -> bool { |
1728 | for entry in &self.entries { |
1729 | if !entry.exts.is_empty() { |
1730 | return true; |
1731 | } |
1732 | } |
1733 | |
1734 | false |
1735 | } |
1736 | |
1737 | pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> { |
1738 | self.entries |
1739 | .first() |
1740 | .and_then(CertificateEntry::ocsp_response) |
1741 | .map(|resp| resp.to_vec()) |
1742 | .unwrap_or_default() |
1743 | } |
1744 | |
1745 | pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> { |
1746 | CertificateChain( |
1747 | self.entries |
1748 | .into_iter() |
1749 | .map(|e| e.cert) |
1750 | .collect(), |
1751 | ) |
1752 | } |
1753 | } |
1754 | |
/// Describes supported key exchange mechanisms.
#[derive (Clone, Copy, Debug, PartialEq)]
#[non_exhaustive ]
pub enum KeyExchangeAlgorithm {
    /// Finite-field Diffie-Hellman key exchange, restricted to the known
    /// parameter sets defined in [RFC 7919].
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}

/// Every `KeyExchangeAlgorithm` variant, with ECDHE listed first.
// NOTE(review): callers may rely on this order as a preference — confirm at use sites.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1769 | |
// Only named curves are supported: decoding rejects any curve type other than
// `ECCurveType::NamedCurve`. Arbitrary (explicit) curve parameters are a bad
// idea and unnecessary attack surface.
/// The elliptic-curve parameters that precede an ECDH public key (see
/// `ServerEcdhParams`).
#[derive (Debug)]
pub(crate) struct EcParameters {
    /// Always `ECCurveType::NamedCurve` for decoded values.
    pub(crate) curve_type: ECCurveType,
    /// The named curve/group in use.
    pub(crate) named_group: NamedGroup,
}
1778 | |
1779 | impl Codec<'_> for EcParameters { |
1780 | fn encode(&self, bytes: &mut Vec<u8>) { |
1781 | self.curve_type.encode(bytes); |
1782 | self.named_group.encode(bytes); |
1783 | } |
1784 | |
1785 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1786 | let ct: ECCurveType = ECCurveType::read(r)?; |
1787 | if ct != ECCurveType::NamedCurve { |
1788 | return Err(InvalidMessage::UnsupportedCurveType); |
1789 | } |
1790 | |
1791 | let grp: NamedGroup = NamedGroup::read(r)?; |
1792 | |
1793 | Ok(Self { |
1794 | curve_type: ct, |
1795 | named_group: grp, |
1796 | }) |
1797 | } |
1798 | } |
1799 | |
/// Types that can be decoded from a TLS 1.2 key exchange message once the
/// negotiated key exchange algorithm is known.
#[cfg (feature = "tls12" )]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key exchange `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1805 | |
/// The client's public key exchange value in a TLS 1.2 ClientKeyExchange,
/// shaped according to the negotiated algorithm.
#[cfg (feature = "tls12" )]
#[derive (Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// ECDHE public point.
    Ecdh(ClientEcdhParams),
    /// Finite-field DHE public value.
    Dh(ClientDhParams),
}
1812 | |
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the raw public key bytes, without the length prefix.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Dh(params) => &params.public.0,
            Self::Ecdh(params) => &params.public.0,
        }
    }

    /// Appends the encoding of whichever parameter set this holds to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(params) => params.encode(buf),
            Self::Ecdh(params) => params.encode(buf),
        }
    }
}
1829 | |
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Reads ECDH params for ECDHE and DH params for DHE.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1840 | |
/// Client ECDHE key exchange value.
#[cfg (feature = "tls12" )]
#[derive (Debug)]
pub(crate) struct ClientEcdhParams {
    /// The public point, encoded with a one-byte length prefix.
    pub(crate) public: PayloadU8,
}
1846 | |
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}
1858 | |
#[cfg (feature = "tls12" )]
/// Client finite-field DHE key-exchange parameters: a single public value
/// encoded as a `PayloadU16`.
#[derive (Debug)]
pub(crate) struct ClientDhParams {
    /// The client's public value; exposed through `ClientKeyExchangeParams::pub_key`.
    pub(crate) public: PayloadU16,
}
1864 | |
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Writes the public value as a `PayloadU16`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        Codec::encode(&self.public, bytes);
    }

    /// Reads a single `PayloadU16` public value.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1877 | |
/// Server ECDHE parameters: the curve description followed by the server's
/// public value.
#[derive (Debug)]
pub(crate) struct ServerEcdhParams {
    /// Curve description; `EcParameters::read` only accepts named curves.
    pub(crate) curve_params: EcParameters,
    /// The server's public value, encoded as a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1883 | |
1884 | impl ServerEcdhParams { |
1885 | #[cfg (feature = "tls12" )] |
1886 | pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self { |
1887 | Self { |
1888 | curve_params: EcParameters { |
1889 | curve_type: ECCurveType::NamedCurve, |
1890 | named_group: kx.group(), |
1891 | }, |
1892 | public: PayloadU8::new(bytes:kx.pub_key().to_vec()), |
1893 | } |
1894 | } |
1895 | } |
1896 | |
1897 | impl Codec<'_> for ServerEcdhParams { |
1898 | fn encode(&self, bytes: &mut Vec<u8>) { |
1899 | self.curve_params.encode(bytes); |
1900 | self.public.encode(bytes); |
1901 | } |
1902 | |
1903 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1904 | let cp: EcParameters = EcParameters::read(r)?; |
1905 | let pb: PayloadU8 = PayloadU8::read(r)?; |
1906 | |
1907 | Ok(Self { |
1908 | curve_params: cp, |
1909 | public: pb, |
1910 | }) |
1911 | } |
1912 | } |
1913 | |
/// Server finite-field DHE parameters: the group's `p` and `g`, followed by
/// the server's public value.
#[derive (Debug)]
// Needed for the `dh_Ys` field name, which is not snake case (presumably kept
// to mirror the RFC 5246 `ServerDHParams` field names).
#[allow (non_snake_case)]
pub(crate) struct ServerDhParams {
    /// The group's `p` value (see `as_ffdhe_group`).
    pub(crate) dh_p: PayloadU16,
    /// The group's `g` value (see `as_ffdhe_group`).
    pub(crate) dh_g: PayloadU16,
    /// The server's public value; returned by `ServerKeyExchangeParams::pub_key`.
    pub(crate) dh_Ys: PayloadU16,
}
1921 | |
1922 | impl ServerDhParams { |
1923 | #[cfg (feature = "tls12" )] |
1924 | pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self { |
1925 | let Some(params: FfdheGroup<'static>) = kx.ffdhe_group() else { |
1926 | panic!("invalid NamedGroup for DHE key exchange: {:?}" , kx.group()); |
1927 | }; |
1928 | |
1929 | Self { |
1930 | dh_p: PayloadU16::new(bytes:params.p.to_vec()), |
1931 | dh_g: PayloadU16::new(bytes:params.g.to_vec()), |
1932 | dh_Ys: PayloadU16::new(bytes:kx.pub_key().to_vec()), |
1933 | } |
1934 | } |
1935 | |
1936 | #[cfg (feature = "tls12" )] |
1937 | pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> { |
1938 | FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0) |
1939 | } |
1940 | } |
1941 | |
1942 | impl Codec<'_> for ServerDhParams { |
1943 | fn encode(&self, bytes: &mut Vec<u8>) { |
1944 | self.dh_p.encode(bytes); |
1945 | self.dh_g.encode(bytes); |
1946 | self.dh_Ys.encode(bytes); |
1947 | } |
1948 | |
1949 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
1950 | Ok(Self { |
1951 | dh_p: PayloadU16::read(r)?, |
1952 | dh_g: PayloadU16::read(r)?, |
1953 | dh_Ys: PayloadU16::read(r)?, |
1954 | }) |
1955 | } |
1956 | } |
1957 | |
// NOTE(review): `dead_code` is presumably allowed because the constructors
// (`new`, `KxDecode::decode`) are only compiled with the `tls12` feature — confirm.
#[allow (dead_code)]
/// Parameters carried by a server key-exchange message, one variant per
/// key-exchange algorithm.
#[derive (Debug)]
pub(crate) enum ServerKeyExchangeParams {
    /// Used for `KeyExchangeAlgorithm::ECDHE`.
    Ecdh(ServerEcdhParams),
    /// Used for `KeyExchangeAlgorithm::DHE`.
    Dh(ServerDhParams),
}
1964 | |
1965 | impl ServerKeyExchangeParams { |
1966 | #[cfg (feature = "tls12" )] |
1967 | pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self { |
1968 | match kx.group().key_exchange_algorithm() { |
1969 | KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)), |
1970 | KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)), |
1971 | } |
1972 | } |
1973 | |
1974 | #[cfg (feature = "tls12" )] |
1975 | pub(crate) fn pub_key(&self) -> &[u8] { |
1976 | match self { |
1977 | Self::Ecdh(ecdh: &ServerEcdhParams) => &ecdh.public.0, |
1978 | Self::Dh(dh: &ServerDhParams) => &dh.dh_Ys.0, |
1979 | } |
1980 | } |
1981 | |
1982 | pub(crate) fn encode(&self, buf: &mut Vec<u8>) { |
1983 | match self { |
1984 | Self::Ecdh(ecdh: &ServerEcdhParams) => ecdh.encode(bytes:buf), |
1985 | Self::Dh(dh: &ServerDhParams) => dh.encode(bytes:buf), |
1986 | } |
1987 | } |
1988 | } |
1989 | |
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Parse the server's key-exchange parameters using the layout for `algo`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        let params = match algo {
            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::read(r)?),
        };
        Ok(params)
    }
}
2000 | |
/// A fully parsed server key-exchange message: key-exchange parameters
/// followed by a `DigitallySignedStruct`.
#[derive (Debug)]
pub struct ServerKeyExchange {
    pub(crate) params: ServerKeyExchangeParams,
    pub(crate) dss: DigitallySignedStruct,
}
2006 | |
2007 | impl ServerKeyExchange { |
2008 | pub fn encode(&self, buf: &mut Vec<u8>) { |
2009 | self.params.encode(buf); |
2010 | self.dss.encode(bytes:buf); |
2011 | } |
2012 | } |
2013 | |
/// A server key-exchange message body, either parsed or still raw.
///
/// `Codec::read` always produces `Unknown`, because the layout depends on the
/// negotiated key-exchange algorithm; `unwrap_given_kxa` parses it once that
/// is known.
#[derive (Debug)]
pub enum ServerKeyExchangePayload {
    /// Parsed message.
    Known(ServerKeyExchange),
    /// Unparsed message bytes.
    Unknown(Payload<'static>),
}
2019 | |
2020 | impl From<ServerKeyExchange> for ServerKeyExchangePayload { |
2021 | fn from(value: ServerKeyExchange) -> Self { |
2022 | Self::Known(value) |
2023 | } |
2024 | } |
2025 | |
2026 | impl Codec<'_> for ServerKeyExchangePayload { |
2027 | fn encode(&self, bytes: &mut Vec<u8>) { |
2028 | match self { |
2029 | Self::Known(x: &ServerKeyExchange) => x.encode(buf:bytes), |
2030 | Self::Unknown(x: &Payload<'static>) => x.encode(bytes), |
2031 | } |
2032 | } |
2033 | |
2034 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2035 | // read as Unknown, fully parse when we know the |
2036 | // KeyExchangeAlgorithm |
2037 | Ok(Self::Unknown(Payload::read(r).into_owned())) |
2038 | } |
2039 | } |
2040 | |
2041 | impl ServerKeyExchangePayload { |
2042 | #[cfg (feature = "tls12" )] |
2043 | pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> { |
2044 | if let Self::Unknown(unk: &Payload<'static>) = self { |
2045 | let mut rd: Reader<'_> = Reader::init(unk.bytes()); |
2046 | |
2047 | let result: ServerKeyExchange = ServerKeyExchange { |
2048 | params: ServerKeyExchangeParams::decode(&mut rd, algo:kxa).ok()?, |
2049 | dss: DigitallySignedStruct::read(&mut rd).ok()?, |
2050 | }; |
2051 | |
2052 | if !rd.any_left() { |
2053 | return Some(result); |
2054 | }; |
2055 | } |
2056 | |
2057 | None |
2058 | } |
2059 | } |
2060 | |
2061 | // -- EncryptedExtensions (TLS1.3 only) -- |
2062 | |
// Lists of `ServerExtension`s are written with a `ListLength::U16` length prefix.
impl TlsListElement for ServerExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2066 | |
2067 | pub(crate) trait HasServerExtensions { |
2068 | fn extensions(&self) -> &[ServerExtension]; |
2069 | |
2070 | /// Returns true if there is more than one extension of a given |
2071 | /// type. |
2072 | fn has_duplicate_extension(&self) -> bool { |
2073 | has_duplicates::<_, _, u16>( |
2074 | self.extensions() |
2075 | .iter() |
2076 | .map(|ext| ext.ext_type()), |
2077 | ) |
2078 | } |
2079 | |
2080 | fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> { |
2081 | self.extensions() |
2082 | .iter() |
2083 | .find(|x| x.ext_type() == ext) |
2084 | } |
2085 | |
2086 | fn alpn_protocol(&self) -> Option<&[u8]> { |
2087 | let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?; |
2088 | match ext { |
2089 | ServerExtension::Protocols(protos) => protos.as_single_slice(), |
2090 | _ => None, |
2091 | } |
2092 | } |
2093 | |
2094 | fn server_cert_type(&self) -> Option<&CertificateType> { |
2095 | let ext = self.find_extension(ExtensionType::ServerCertificateType)?; |
2096 | match ext { |
2097 | ServerExtension::ServerCertType(req) => Some(req), |
2098 | _ => None, |
2099 | } |
2100 | } |
2101 | |
2102 | fn client_cert_type(&self) -> Option<&CertificateType> { |
2103 | let ext = self.find_extension(ExtensionType::ClientCertificateType)?; |
2104 | match ext { |
2105 | ServerExtension::ClientCertType(req) => Some(req), |
2106 | _ => None, |
2107 | } |
2108 | } |
2109 | |
2110 | fn quic_params_extension(&self) -> Option<Vec<u8>> { |
2111 | let ext = self |
2112 | .find_extension(ExtensionType::TransportParameters) |
2113 | .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?; |
2114 | match ext { |
2115 | ServerExtension::TransportParameters(bytes) |
2116 | | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()), |
2117 | _ => None, |
2118 | } |
2119 | } |
2120 | |
2121 | fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> { |
2122 | let ext = self.find_extension(ExtensionType::EncryptedClientHello)?; |
2123 | match ext { |
2124 | ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()), |
2125 | _ => None, |
2126 | } |
2127 | } |
2128 | |
2129 | fn early_data_extension_offered(&self) -> bool { |
2130 | self.find_extension(ExtensionType::EarlyData) |
2131 | .is_some() |
2132 | } |
2133 | } |
2134 | |
2135 | impl HasServerExtensions for Vec<ServerExtension> { |
2136 | fn extensions(&self) -> &[ServerExtension] { |
2137 | self |
2138 | } |
2139 | } |
2140 | |
// Lists of `ClientCertificateType`s are written with a `ListLength::U8` length prefix.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
2144 | |
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    pub struct DistinguishedName,
    PayloadU16,
);
2161 | |
2162 | impl DistinguishedName { |
2163 | /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding. |
2164 | /// |
2165 | /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html). |
2166 | /// |
2167 | /// ```ignore |
2168 | /// use x509_parser::prelude::FromDer; |
2169 | /// println!("{}" , x509_parser::x509::X509Name::from_der(dn.as_ref())?.1); |
2170 | /// ``` |
2171 | pub fn in_sequence(bytes: &[u8]) -> Self { |
2172 | Self(PayloadU16::new(bytes:wrap_in_sequence(bytes))) |
2173 | } |
2174 | } |
2175 | |
// Lists of `DistinguishedName`s are written with a `ListLength::U16` length prefix.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2179 | |
/// CertificateRequest body for versions before TLS 1.3 (`read_version` uses
/// `CertificateRequestPayloadTls13` when given TLS 1.3).
#[derive (Debug)]
pub struct CertificateRequestPayload {
    /// Acceptable client certificate types.
    pub(crate) certtypes: Vec<ClientCertificateType>,
    /// Acceptable signature schemes; `read` rejects an empty list.
    pub(crate) sigschemes: Vec<SignatureScheme>,
    /// CA distinguished names; may be empty.
    pub(crate) canames: Vec<DistinguishedName>,
}
2186 | |
2187 | impl Codec<'_> for CertificateRequestPayload { |
2188 | fn encode(&self, bytes: &mut Vec<u8>) { |
2189 | self.certtypes.encode(bytes); |
2190 | self.sigschemes.encode(bytes); |
2191 | self.canames.encode(bytes); |
2192 | } |
2193 | |
2194 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2195 | let certtypes: Vec = Vec::read(r)?; |
2196 | let sigschemes: Vec = Vec::read(r)?; |
2197 | let canames: Vec = Vec::read(r)?; |
2198 | |
2199 | if sigschemes.is_empty() { |
2200 | warn!("meaningless CertificateRequest message" ); |
2201 | Err(InvalidMessage::NoSignatureSchemes) |
2202 | } else { |
2203 | Ok(Self { |
2204 | certtypes, |
2205 | sigschemes, |
2206 | canames, |
2207 | }) |
2208 | } |
2209 | } |
2210 | } |
2211 | |
/// An extension carried in a TLS 1.3 CertificateRequest.
#[derive (Debug)]
pub(crate) enum CertReqExtension {
    /// Never empty when produced by `read`.
    SignatureAlgorithms(Vec<SignatureScheme>),
    AuthorityNames(Vec<DistinguishedName>),
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
2219 | |
2220 | impl CertReqExtension { |
2221 | pub(crate) fn ext_type(&self) -> ExtensionType { |
2222 | match self { |
2223 | Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms, |
2224 | Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities, |
2225 | Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate, |
2226 | Self::Unknown(r: &UnknownExtension) => r.typ, |
2227 | } |
2228 | } |
2229 | } |
2230 | |
2231 | impl Codec<'_> for CertReqExtension { |
2232 | fn encode(&self, bytes: &mut Vec<u8>) { |
2233 | self.ext_type().encode(bytes); |
2234 | |
2235 | let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
2236 | match self { |
2237 | Self::SignatureAlgorithms(r) => r.encode(nested.buf), |
2238 | Self::AuthorityNames(r) => r.encode(nested.buf), |
2239 | Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf), |
2240 | Self::Unknown(r) => r.encode(nested.buf), |
2241 | } |
2242 | } |
2243 | |
2244 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2245 | let typ = ExtensionType::read(r)?; |
2246 | let len = u16::read(r)? as usize; |
2247 | let mut sub = r.sub(len)?; |
2248 | |
2249 | let ext = match typ { |
2250 | ExtensionType::SignatureAlgorithms => { |
2251 | let schemes = Vec::read(&mut sub)?; |
2252 | if schemes.is_empty() { |
2253 | return Err(InvalidMessage::NoSignatureSchemes); |
2254 | } |
2255 | Self::SignatureAlgorithms(schemes) |
2256 | } |
2257 | ExtensionType::CertificateAuthorities => { |
2258 | let cas = Vec::read(&mut sub)?; |
2259 | Self::AuthorityNames(cas) |
2260 | } |
2261 | ExtensionType::CompressCertificate => { |
2262 | Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?) |
2263 | } |
2264 | _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)), |
2265 | }; |
2266 | |
2267 | sub.expect_empty("CertReqExtension" ) |
2268 | .map(|_| ext) |
2269 | } |
2270 | } |
2271 | |
// Lists of `CertReqExtension`s are written with a `ListLength::U16` length prefix.
impl TlsListElement for CertReqExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2275 | |
/// TLS 1.3 CertificateRequest body: a request context plus extensions.
#[derive (Debug)]
pub struct CertificateRequestPayloadTls13 {
    /// Request context bytes, encoded as a `PayloadU8`.
    pub(crate) context: PayloadU8,
    /// Extensions; looked up with `find_extension`.
    pub(crate) extensions: Vec<CertReqExtension>,
}
2281 | |
2282 | impl Codec<'_> for CertificateRequestPayloadTls13 { |
2283 | fn encode(&self, bytes: &mut Vec<u8>) { |
2284 | self.context.encode(bytes); |
2285 | self.extensions.encode(bytes); |
2286 | } |
2287 | |
2288 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2289 | let context: PayloadU8 = PayloadU8::read(r)?; |
2290 | let extensions: Vec = Vec::read(r)?; |
2291 | |
2292 | Ok(Self { |
2293 | context, |
2294 | extensions, |
2295 | }) |
2296 | } |
2297 | } |
2298 | |
2299 | impl CertificateRequestPayloadTls13 { |
2300 | pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> { |
2301 | self.extensions |
2302 | .iter() |
2303 | .find(|x| x.ext_type() == ext) |
2304 | } |
2305 | |
2306 | pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> { |
2307 | let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?; |
2308 | match ext { |
2309 | CertReqExtension::SignatureAlgorithms(sa) => Some(sa), |
2310 | _ => None, |
2311 | } |
2312 | } |
2313 | |
2314 | pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> { |
2315 | let ext = self.find_extension(ExtensionType::CertificateAuthorities)?; |
2316 | match ext { |
2317 | CertReqExtension::AuthorityNames(an) => Some(an), |
2318 | _ => None, |
2319 | } |
2320 | } |
2321 | |
2322 | pub(crate) fn certificate_compression_extension( |
2323 | &self, |
2324 | ) -> Option<&[CertificateCompressionAlgorithm]> { |
2325 | let ext = self.find_extension(ExtensionType::CompressCertificate)?; |
2326 | match ext { |
2327 | CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps), |
2328 | _ => None, |
2329 | } |
2330 | } |
2331 | } |
2332 | |
2333 | // -- NewSessionTicket -- |
/// NewSessionTicket body for versions before TLS 1.3 (`read_version` uses
/// `NewSessionTicketPayloadTls13` when given TLS 1.3).
#[derive (Debug)]
pub struct NewSessionTicketPayload {
    /// Ticket lifetime hint, written as a `u32`.
    pub(crate) lifetime_hint: u32,
    // Tickets can be large (KB), so we deserialise this straight
    // into an Arc, so it can be passed directly into the client's
    // session object without copying.
    pub(crate) ticket: Arc<PayloadU16>,
}
2342 | |
2343 | impl NewSessionTicketPayload { |
2344 | #[cfg (feature = "tls12" )] |
2345 | pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self { |
2346 | Self { |
2347 | lifetime_hint, |
2348 | ticket: Arc::new(data:PayloadU16::new(bytes:ticket)), |
2349 | } |
2350 | } |
2351 | } |
2352 | |
2353 | impl Codec<'_> for NewSessionTicketPayload { |
2354 | fn encode(&self, bytes: &mut Vec<u8>) { |
2355 | self.lifetime_hint.encode(bytes); |
2356 | self.ticket.encode(bytes); |
2357 | } |
2358 | |
2359 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2360 | let lifetime: u32 = u32::read(r)?; |
2361 | let ticket: Arc = Arc::new(data:PayloadU16::read(r)?); |
2362 | |
2363 | Ok(Self { |
2364 | lifetime_hint: lifetime, |
2365 | ticket, |
2366 | }) |
2367 | } |
2368 | } |
2369 | |
// -- NewSessionTicket, TLS 1.3 form --
/// An extension carried in a TLS 1.3 NewSessionTicket.
#[derive (Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// Early-data size limit; read by `NewSessionTicketPayloadTls13::max_early_data_size`.
    EarlyData(u32),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
2376 | |
2377 | impl NewSessionTicketExtension { |
2378 | pub(crate) fn ext_type(&self) -> ExtensionType { |
2379 | match self { |
2380 | Self::EarlyData(_) => ExtensionType::EarlyData, |
2381 | Self::Unknown(r: &UnknownExtension) => r.typ, |
2382 | } |
2383 | } |
2384 | } |
2385 | |
2386 | impl Codec<'_> for NewSessionTicketExtension { |
2387 | fn encode(&self, bytes: &mut Vec<u8>) { |
2388 | self.ext_type().encode(bytes); |
2389 | |
2390 | let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
2391 | match self { |
2392 | Self::EarlyData(r) => r.encode(nested.buf), |
2393 | Self::Unknown(r) => r.encode(nested.buf), |
2394 | } |
2395 | } |
2396 | |
2397 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2398 | let typ = ExtensionType::read(r)?; |
2399 | let len = u16::read(r)? as usize; |
2400 | let mut sub = r.sub(len)?; |
2401 | |
2402 | let ext = match typ { |
2403 | ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?), |
2404 | _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)), |
2405 | }; |
2406 | |
2407 | sub.expect_empty("NewSessionTicketExtension" ) |
2408 | .map(|_| ext) |
2409 | } |
2410 | } |
2411 | |
// Lists of `NewSessionTicketExtension`s are written with a `ListLength::U16` length prefix.
impl TlsListElement for NewSessionTicketExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2415 | |
/// TLS 1.3 NewSessionTicket body.
#[derive (Debug)]
pub struct NewSessionTicketPayloadTls13 {
    /// Ticket lifetime, written as a `u32`.
    pub(crate) lifetime: u32,
    /// Ticket age-add value, written as a `u32`.
    pub(crate) age_add: u32,
    /// Ticket nonce, encoded as a `PayloadU8`.
    pub(crate) nonce: PayloadU8,
    /// The ticket; held in an `Arc`, like `NewSessionTicketPayload::ticket`.
    pub(crate) ticket: Arc<PayloadU16>,
    /// Extensions; `max_early_data_size` reads the `EarlyData` one.
    pub(crate) exts: Vec<NewSessionTicketExtension>,
}
2424 | |
2425 | impl NewSessionTicketPayloadTls13 { |
2426 | pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self { |
2427 | Self { |
2428 | lifetime, |
2429 | age_add, |
2430 | nonce: PayloadU8::new(nonce), |
2431 | ticket: Arc::new(PayloadU16::new(ticket)), |
2432 | exts: vec![], |
2433 | } |
2434 | } |
2435 | |
2436 | pub(crate) fn has_duplicate_extension(&self) -> bool { |
2437 | has_duplicates::<_, _, u16>( |
2438 | self.exts |
2439 | .iter() |
2440 | .map(|ext| ext.ext_type()), |
2441 | ) |
2442 | } |
2443 | |
2444 | pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> { |
2445 | self.exts |
2446 | .iter() |
2447 | .find(|x| x.ext_type() == ext) |
2448 | } |
2449 | |
2450 | pub(crate) fn max_early_data_size(&self) -> Option<u32> { |
2451 | let ext = self.find_extension(ExtensionType::EarlyData)?; |
2452 | match ext { |
2453 | NewSessionTicketExtension::EarlyData(sz) => Some(*sz), |
2454 | _ => None, |
2455 | } |
2456 | } |
2457 | } |
2458 | |
2459 | impl Codec<'_> for NewSessionTicketPayloadTls13 { |
2460 | fn encode(&self, bytes: &mut Vec<u8>) { |
2461 | self.lifetime.encode(bytes); |
2462 | self.age_add.encode(bytes); |
2463 | self.nonce.encode(bytes); |
2464 | self.ticket.encode(bytes); |
2465 | self.exts.encode(bytes); |
2466 | } |
2467 | |
2468 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2469 | let lifetime = u32::read(r)?; |
2470 | let age_add = u32::read(r)?; |
2471 | let nonce = PayloadU8::read(r)?; |
2472 | let ticket = Arc::new(PayloadU16::read(r)?); |
2473 | let exts = Vec::read(r)?; |
2474 | |
2475 | Ok(Self { |
2476 | lifetime, |
2477 | age_add, |
2478 | nonce, |
2479 | ticket, |
2480 | exts, |
2481 | }) |
2482 | } |
2483 | } |
2484 | |
2485 | // -- RFC6066 certificate status types |
2486 | |
/// Only supports OCSP
///
/// `read` rejects any other status type with
/// `InvalidMessage::InvalidCertificateStatusType`.
#[derive (Debug)]
pub struct CertificateStatus<'a> {
    /// The OCSP response bytes, encoded as a `PayloadU24`.
    pub(crate) ocsp_response: PayloadU24<'a>,
}
2492 | |
2493 | impl<'a> Codec<'a> for CertificateStatus<'a> { |
2494 | fn encode(&self, bytes: &mut Vec<u8>) { |
2495 | CertificateStatusType::OCSP.encode(bytes); |
2496 | self.ocsp_response.encode(bytes); |
2497 | } |
2498 | |
2499 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
2500 | let typ: CertificateStatusType = CertificateStatusType::read(r)?; |
2501 | |
2502 | match typ { |
2503 | CertificateStatusType::OCSP => Ok(Self { |
2504 | ocsp_response: PayloadU24::read(r)?, |
2505 | }), |
2506 | _ => Err(InvalidMessage::InvalidCertificateStatusType), |
2507 | } |
2508 | } |
2509 | } |
2510 | |
2511 | impl<'a> CertificateStatus<'a> { |
2512 | pub(crate) fn new(ocsp: &'a [u8]) -> Self { |
2513 | CertificateStatus { |
2514 | ocsp_response: PayloadU24(Payload::Borrowed(ocsp)), |
2515 | } |
2516 | } |
2517 | |
2518 | #[cfg (feature = "tls12" )] |
2519 | pub(crate) fn into_inner(self) -> Vec<u8> { |
2520 | self.ocsp_response.0.into_vec() |
2521 | } |
2522 | |
2523 | pub(crate) fn into_owned(self) -> CertificateStatus<'static> { |
2524 | CertificateStatus { |
2525 | ocsp_response: self.ocsp_response.into_owned(), |
2526 | } |
2527 | } |
2528 | } |
2529 | |
2530 | // -- RFC8879 compressed certificates |
2531 | |
/// RFC 8879 CompressedCertificate body.
#[derive (Debug)]
pub struct CompressedCertificatePayload<'a> {
    /// Compression algorithm used for `compressed`.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// Length after decompression; written on the wire as a `codec::u24`.
    pub(crate) uncompressed_len: u32,
    /// The compressed bytes, encoded as a `PayloadU24`.
    pub(crate) compressed: PayloadU24<'a>,
}
2538 | |
2539 | impl<'a> Codec<'a> for CompressedCertificatePayload<'a> { |
2540 | fn encode(&self, bytes: &mut Vec<u8>) { |
2541 | self.alg.encode(bytes); |
2542 | codec::u24(self.uncompressed_len).encode(bytes); |
2543 | self.compressed.encode(bytes); |
2544 | } |
2545 | |
2546 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
2547 | Ok(Self { |
2548 | alg: CertificateCompressionAlgorithm::read(r)?, |
2549 | uncompressed_len: codec::u24::read(r)?.0, |
2550 | compressed: PayloadU24::read(r)?, |
2551 | }) |
2552 | } |
2553 | } |
2554 | |
2555 | impl CompressedCertificatePayload<'_> { |
2556 | fn into_owned(self) -> CompressedCertificatePayload<'static> { |
2557 | CompressedCertificatePayload { |
2558 | compressed: self.compressed.into_owned(), |
2559 | ..self |
2560 | } |
2561 | } |
2562 | |
2563 | pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> { |
2564 | CompressedCertificatePayload { |
2565 | alg: self.alg, |
2566 | uncompressed_len: self.uncompressed_len, |
2567 | compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())), |
2568 | } |
2569 | } |
2570 | } |
2571 | |
/// The body of a handshake message, one variant per message kind.
///
/// Several message types have separate TLS 1.3 variants; which one
/// `HandshakeMessagePayload::read_version` produces depends on the protocol
/// version it is given.
#[derive (Debug)]
pub enum HandshakePayload<'a> {
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// A ServerHello whose random equals `HELLO_RETRY_REQUEST_RANDOM`.
    HelloRetryRequest(HelloRetryRequest),
    Certificate(CertificateChain<'a>),
    /// TLS 1.3 form of `Certificate`.
    CertificateTls13(CertificatePayloadTls13<'a>),
    /// RFC 8879 compressed certificate.
    CompressedCertificate(CompressedCertificatePayload<'a>),
    /// Kept unparsed until the key-exchange algorithm is known.
    ServerKeyExchange(ServerKeyExchangePayload),
    CertificateRequest(CertificateRequestPayload),
    /// TLS 1.3 form of `CertificateRequest`.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    ServerHelloDone,
    EndOfEarlyData,
    /// Raw bytes; `read_version` does not decode them further.
    ClientKeyExchange(Payload<'a>),
    NewSessionTicket(NewSessionTicketPayload),
    /// TLS 1.3 form of `NewSessionTicket`.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// Rejected by `read_version` if it appears on the wire.
    MessageHash(Payload<'a>),
    /// Any other handshake type, including a HelloRequest with a non-empty body.
    Unknown(Payload<'a>),
}
2597 | |
2598 | impl HandshakePayload<'_> { |
2599 | fn encode(&self, bytes: &mut Vec<u8>) { |
2600 | use self::HandshakePayload::*; |
2601 | match self { |
2602 | HelloRequest | ServerHelloDone | EndOfEarlyData => {} |
2603 | ClientHello(x) => x.encode(bytes), |
2604 | ServerHello(x) => x.encode(bytes), |
2605 | HelloRetryRequest(x) => x.encode(bytes), |
2606 | Certificate(x) => x.encode(bytes), |
2607 | CertificateTls13(x) => x.encode(bytes), |
2608 | CompressedCertificate(x) => x.encode(bytes), |
2609 | ServerKeyExchange(x) => x.encode(bytes), |
2610 | ClientKeyExchange(x) => x.encode(bytes), |
2611 | CertificateRequest(x) => x.encode(bytes), |
2612 | CertificateRequestTls13(x) => x.encode(bytes), |
2613 | CertificateVerify(x) => x.encode(bytes), |
2614 | NewSessionTicket(x) => x.encode(bytes), |
2615 | NewSessionTicketTls13(x) => x.encode(bytes), |
2616 | EncryptedExtensions(x) => x.encode(bytes), |
2617 | KeyUpdate(x) => x.encode(bytes), |
2618 | Finished(x) => x.encode(bytes), |
2619 | CertificateStatus(x) => x.encode(bytes), |
2620 | MessageHash(x) => x.encode(bytes), |
2621 | Unknown(x) => x.encode(bytes), |
2622 | } |
2623 | } |
2624 | |
2625 | fn into_owned(self) -> HandshakePayload<'static> { |
2626 | use HandshakePayload::*; |
2627 | |
2628 | match self { |
2629 | HelloRequest => HelloRequest, |
2630 | ClientHello(x) => ClientHello(x), |
2631 | ServerHello(x) => ServerHello(x), |
2632 | HelloRetryRequest(x) => HelloRetryRequest(x), |
2633 | Certificate(x) => Certificate(x.into_owned()), |
2634 | CertificateTls13(x) => CertificateTls13(x.into_owned()), |
2635 | CompressedCertificate(x) => CompressedCertificate(x.into_owned()), |
2636 | ServerKeyExchange(x) => ServerKeyExchange(x), |
2637 | CertificateRequest(x) => CertificateRequest(x), |
2638 | CertificateRequestTls13(x) => CertificateRequestTls13(x), |
2639 | CertificateVerify(x) => CertificateVerify(x), |
2640 | ServerHelloDone => ServerHelloDone, |
2641 | EndOfEarlyData => EndOfEarlyData, |
2642 | ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()), |
2643 | NewSessionTicket(x) => NewSessionTicket(x), |
2644 | NewSessionTicketTls13(x) => NewSessionTicketTls13(x), |
2645 | EncryptedExtensions(x) => EncryptedExtensions(x), |
2646 | KeyUpdate(x) => KeyUpdate(x), |
2647 | Finished(x) => Finished(x.into_owned()), |
2648 | CertificateStatus(x) => CertificateStatus(x.into_owned()), |
2649 | MessageHash(x) => MessageHash(x.into_owned()), |
2650 | Unknown(x) => Unknown(x.into_owned()), |
2651 | } |
2652 | } |
2653 | } |
2654 | |
/// A handshake message: its type together with the decoded body.
#[derive (Debug)]
pub struct HandshakeMessagePayload<'a> {
    /// Message type. `HelloRetryRequest` is written on the wire with the
    /// `ServerHello` type code (see `payload_encode`).
    pub typ: HandshakeType,
    /// The decoded body.
    pub payload: HandshakePayload<'a>,
}
2660 | |
2661 | impl<'a> Codec<'a> for HandshakeMessagePayload<'a> { |
2662 | fn encode(&self, bytes: &mut Vec<u8>) { |
2663 | self.payload_encode(bytes, Encoding::Standard); |
2664 | } |
2665 | |
2666 | fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> { |
2667 | Self::read_version(r, vers:ProtocolVersion::TLSv1_2) |
2668 | } |
2669 | } |
2670 | |
2671 | impl<'a> HandshakeMessagePayload<'a> { |
    /// Decode one handshake message from `r`, choosing version-dependent
    /// message forms according to `vers`.
    ///
    /// A `ServerHello` whose random equals `HELLO_RETRY_REQUEST_RANDOM` is
    /// returned as a `HelloRetryRequest`, with `typ` rewritten to match.
    /// The `MessageHash` and `HelloRetryRequest` type codes are rejected, and
    /// any bytes left over inside the declared body length are an error.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        // Header: the handshake type, then a 24-bit body length. `sub` is a
        // reader limited to exactly that body.
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        let mut sub = r.sub(len)?;

        let payload = match typ {
            // A HelloRequest with a non-empty body falls through to `Unknown`.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // HelloRetryRequest shares the ServerHello type code; it is
                // recognised by its fixed random value.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            // Guarded TLS 1.3 arms must stay ahead of their unguarded
            // fallbacks for the same type.
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                // Kept raw here; see `ServerKeyExchangePayload::read`.
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone" )?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Kept as raw bytes.
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData" )?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash" ));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest" ));
            }
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        // Reject trailing bytes inside the declared body length.
        sub.expect_empty("HandshakeMessagePayload" )
            .map(|_| Self { typ, payload })
    }
2769 | |
2770 | pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> { |
2771 | let mut ret = self.get_encoding(); |
2772 | let ret_len = ret.len() - self.total_binder_length(); |
2773 | ret.truncate(ret_len); |
2774 | ret |
2775 | } |
2776 | |
2777 | pub(crate) fn total_binder_length(&self) -> usize { |
2778 | match &self.payload { |
2779 | HandshakePayload::ClientHello(ch) => match ch.extensions.last() { |
2780 | Some(ClientExtension::PresharedKey(offer)) => { |
2781 | let mut binders_encoding = Vec::new(); |
2782 | offer |
2783 | .binders |
2784 | .encode(&mut binders_encoding); |
2785 | binders_encoding.len() |
2786 | } |
2787 | _ => 0, |
2788 | }, |
2789 | _ => 0, |
2790 | } |
2791 | } |
2792 | |
2793 | pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) { |
2794 | // output type, length, and encoded payload |
2795 | match self.typ { |
2796 | HandshakeType::HelloRetryRequest => HandshakeType::ServerHello, |
2797 | _ => self.typ, |
2798 | } |
2799 | .encode(bytes); |
2800 | |
2801 | let nested = LengthPrefixedBuffer::new( |
2802 | ListLength::U24 { |
2803 | max: usize::MAX, |
2804 | error: InvalidMessage::MessageTooLarge, |
2805 | }, |
2806 | bytes, |
2807 | ); |
2808 | |
2809 | match &self.payload { |
2810 | // for Server Hello and HelloRetryRequest payloads we need to encode the payload |
2811 | // differently based on the purpose of the encoding. |
2812 | HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding), |
2813 | HandshakePayload::HelloRetryRequest(payload) => { |
2814 | payload.payload_encode(nested.buf, encoding) |
2815 | } |
2816 | |
2817 | // All other payload types are encoded the same regardless of purpose. |
2818 | _ => self.payload.encode(nested.buf), |
2819 | } |
2820 | } |
2821 | |
2822 | pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self { |
2823 | Self { |
2824 | typ: HandshakeType::MessageHash, |
2825 | payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())), |
2826 | } |
2827 | } |
2828 | |
2829 | pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> { |
2830 | let Self { typ, payload } = self; |
2831 | HandshakeMessagePayload { |
2832 | typ, |
2833 | payload: payload.into_owned(), |
2834 | } |
2835 | } |
2836 | } |
2837 | |
/// An HPKE symmetric cipher suite: a KDF identifier paired with an AEAD identifier.
///
/// On the wire this is the KDF id immediately followed by the AEAD id.
#[derive (Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct HpkeSymmetricCipherSuite {
    /// HPKE key derivation function identifier.
    pub kdf_id: HpkeKdf,
    /// HPKE AEAD identifier.
    pub aead_id: HpkeAead,
}
2843 | |
2844 | impl Codec<'_> for HpkeSymmetricCipherSuite { |
2845 | fn encode(&self, bytes: &mut Vec<u8>) { |
2846 | self.kdf_id.encode(bytes); |
2847 | self.aead_id.encode(bytes); |
2848 | } |
2849 | |
2850 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2851 | Ok(Self { |
2852 | kdf_id: HpkeKdf::read(r)?, |
2853 | aead_id: HpkeAead::read(r)?, |
2854 | }) |
2855 | } |
2856 | } |
2857 | |
impl TlsListElement for HpkeSymmetricCipherSuite {
    // A `Vec<HpkeSymmetricCipherSuite>` is encoded with a two-byte (u16) length prefix.
    const SIZE_LEN: ListLength = ListLength::U16;
}
2861 | |
/// An HPKE public-key configuration, as carried in an ECH config's `key_config`.
#[derive (Clone, Debug, PartialEq)]
pub struct HpkeKeyConfig {
    /// Identifier for this key configuration; a client echoes it back in
    /// `EncryptedClientHelloOuter::config_id`.
    pub config_id: u8,
    /// HPKE KEM identifier that `public_key` belongs to.
    pub kem_id: HpkeKem,
    /// The KEM public key, encoded with a u16 length prefix.
    pub public_key: PayloadU16,
    /// Symmetric cipher suites usable with this key, encoded as a u16-length-prefixed list.
    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
}
2869 | |
2870 | impl Codec<'_> for HpkeKeyConfig { |
2871 | fn encode(&self, bytes: &mut Vec<u8>) { |
2872 | self.config_id.encode(bytes); |
2873 | self.kem_id.encode(bytes); |
2874 | self.public_key.encode(bytes); |
2875 | self.symmetric_cipher_suites |
2876 | .encode(bytes); |
2877 | } |
2878 | |
2879 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2880 | Ok(Self { |
2881 | config_id: u8::read(r)?, |
2882 | kem_id: HpkeKem::read(r)?, |
2883 | public_key: PayloadU16::read(r)?, |
2884 | symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?, |
2885 | }) |
2886 | } |
2887 | } |
2888 | |
/// The body of a version-18 ECH configuration (`ECHConfigContents` in draft-ietf-tls-esni).
#[derive (Clone, Debug, PartialEq)]
pub struct EchConfigContents {
    /// HPKE key configuration offered by this ECH config.
    pub key_config: HpkeKeyConfig,
    /// Single-byte maximum name length. NOTE(review): per the draft this is a padding hint
    /// for the inner ClientHello — confirm against its use sites.
    pub maximum_name_length: u8,
    /// Public DNS name, encoded as a u8-length-prefixed byte string. Presumably the name
    /// used in the outer ClientHello — confirm.
    pub public_name: DnsName<'static>,
    /// Config extensions; duplicates and unknown mandatory types can be detected with
    /// `has_duplicate_extension` / `has_unknown_mandatory_extension`.
    pub extensions: Vec<EchConfigExtension>,
}
2896 | |
2897 | impl EchConfigContents { |
2898 | /// Returns true if there is more than one extension of a given |
2899 | /// type. |
2900 | pub(crate) fn has_duplicate_extension(&self) -> bool { |
2901 | has_duplicates::<_, _, u16>( |
2902 | self.extensions |
2903 | .iter() |
2904 | .map(|ext: &EchConfigExtension| ext.ext_type()), |
2905 | ) |
2906 | } |
2907 | |
2908 | /// Returns true if there is at least one mandatory unsupported extension. |
2909 | pub(crate) fn has_unknown_mandatory_extension(&self) -> bool { |
2910 | self.extensions |
2911 | .iter() |
2912 | // An extension is considered mandatory if the high bit of its type is set. |
2913 | .any(|ext: &EchConfigExtension| { |
2914 | matches!(ext.ext_type(), ExtensionType::Unknown(_)) |
2915 | && u16::from(ext.ext_type()) & 0x8000 != 0 |
2916 | }) |
2917 | } |
2918 | } |
2919 | |
2920 | impl Codec<'_> for EchConfigContents { |
2921 | fn encode(&self, bytes: &mut Vec<u8>) { |
2922 | self.key_config.encode(bytes); |
2923 | self.maximum_name_length.encode(bytes); |
2924 | let dns_name: &DnsName<'_> = &self.public_name.borrow(); |
2925 | PayloadU8::encode_slice(slice:dns_name.as_ref().as_ref(), bytes); |
2926 | self.extensions.encode(bytes); |
2927 | } |
2928 | |
2929 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2930 | Ok(Self { |
2931 | key_config: HpkeKeyConfig::read(r)?, |
2932 | maximum_name_length: u8::read(r)?, |
2933 | public_name: { |
2934 | DnsNameDnsName<'_>::try_from(PayloadU8::read(r)?.0.as_slice()) |
2935 | .map_err(|_| InvalidMessage::InvalidServerName)? |
2936 | .to_owned() |
2937 | }, |
2938 | extensions: Vec::read(r)?, |
2939 | }) |
2940 | } |
2941 | } |
2942 | |
/// An encrypted client hello (ECH) config.
///
/// Encoded as a version, a u16 length, and then the version-specific contents.
#[derive (Clone, Debug, PartialEq)]
pub enum EchConfigPayload {
    /// A recognized V18 ECH configuration.
    V18(EchConfigContents),
    /// An unknown version ECH configuration.
    Unknown {
        /// The unrecognised version number, as read from the wire.
        version: EchVersion,
        /// The opaque contents, kept verbatim so the config can be re-encoded unchanged.
        contents: PayloadU16,
    },
}
2954 | |
impl TlsListElement for EchConfigPayload {
    // A `Vec<EchConfigPayload>` is encoded with a two-byte (u16) length prefix.
    const SIZE_LEN: ListLength = ListLength::U16;
}
2958 | |
2959 | impl Codec<'_> for EchConfigPayload { |
2960 | fn encode(&self, bytes: &mut Vec<u8>) { |
2961 | match self { |
2962 | Self::V18(c) => { |
2963 | // Write the version, the length, and the contents. |
2964 | EchVersion::V18.encode(bytes); |
2965 | let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes); |
2966 | c.encode(inner.buf); |
2967 | } |
2968 | Self::Unknown { version, contents } => { |
2969 | // Unknown configuration versions are opaque. |
2970 | version.encode(bytes); |
2971 | contents.encode(bytes); |
2972 | } |
2973 | } |
2974 | } |
2975 | |
2976 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
2977 | let version = EchVersion::read(r)?; |
2978 | let length = u16::read(r)?; |
2979 | let mut contents = r.sub(length as usize)?; |
2980 | |
2981 | Ok(match version { |
2982 | EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?), |
2983 | _ => { |
2984 | // Note: we don't PayloadU16::read() here because we've already read the length prefix. |
2985 | let data = PayloadU16::new(contents.rest().into()); |
2986 | Self::Unknown { |
2987 | version, |
2988 | contents: data, |
2989 | } |
2990 | } |
2991 | }) |
2992 | } |
2993 | } |
2994 | |
/// An extension inside an ECH configuration's `extensions` list.
///
/// No ECH config extensions are recognised yet, so every extension is parsed as
/// `EchConfigExtension::Unknown`.
#[derive (Clone, Debug, PartialEq)]
pub enum EchConfigExtension {
    /// An extension of any type, kept as its type code and raw payload.
    Unknown(UnknownExtension),
}
2999 | |
3000 | impl EchConfigExtension { |
3001 | pub(crate) fn ext_type(&self) -> ExtensionType { |
3002 | match self { |
3003 | Self::Unknown(r: &UnknownExtension) => r.typ, |
3004 | } |
3005 | } |
3006 | } |
3007 | |
3008 | impl Codec<'_> for EchConfigExtension { |
3009 | fn encode(&self, bytes: &mut Vec<u8>) { |
3010 | self.ext_type().encode(bytes); |
3011 | |
3012 | let nested: LengthPrefixedBuffer<'_> = LengthPrefixedBuffer::new(size_len:ListLength::U16, buf:bytes); |
3013 | match self { |
3014 | Self::Unknown(r: &UnknownExtension) => r.encode(bytes:nested.buf), |
3015 | } |
3016 | } |
3017 | |
3018 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
3019 | let typ: ExtensionType = ExtensionType::read(r)?; |
3020 | let len: usize = u16::read(r)? as usize; |
3021 | let mut sub: Reader<'_> = r.sub(length:len)?; |
3022 | |
3023 | #[allow (clippy::match_single_binding)] // Future-proofing. |
3024 | let ext: EchConfigExtension = match typ { |
3025 | _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)), |
3026 | }; |
3027 | |
3028 | sub.expect_empty("EchConfigExtension" ) |
3029 | .map(|_| ext) |
3030 | } |
3031 | } |
3032 | |
impl TlsListElement for EchConfigExtension {
    // A `Vec<EchConfigExtension>` is encoded with a two-byte (u16) length prefix.
    const SIZE_LEN: ListLength = ListLength::U16;
}
3036 | |
/// Representation of the `ECHClientHello` client extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// On the wire the variant is selected by a leading `EchClientHelloType` value.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive (Clone, Debug)]
pub enum EncryptedClientHello {
    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
    Outer(EncryptedClientHelloOuter),
    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
    ///
    /// This variant has no payload.
    Inner,
}
3050 | |
3051 | impl Codec<'_> for EncryptedClientHello { |
3052 | fn encode(&self, bytes: &mut Vec<u8>) { |
3053 | match self { |
3054 | Self::Outer(payload: &EncryptedClientHelloOuter) => { |
3055 | EchClientHelloType::ClientHelloOuter.encode(bytes); |
3056 | payload.encode(bytes); |
3057 | } |
3058 | Self::Inner => { |
3059 | EchClientHelloType::ClientHelloInner.encode(bytes); |
3060 | // Empty payload. |
3061 | } |
3062 | } |
3063 | } |
3064 | |
3065 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
3066 | match EchClientHelloType::read(r)? { |
3067 | EchClientHelloType::ClientHelloOuter => { |
3068 | Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?)) |
3069 | } |
3070 | EchClientHelloType::ClientHelloInner => Ok(Self::Inner), |
3071 | _ => Err(InvalidMessage::InvalidContentType), |
3072 | } |
3073 | } |
3074 | } |
3075 | |
/// Representation of the ECHClientHello extension with type outer specified in
/// [draft-ietf-tls-esni Section 5].
///
/// Fields are encoded in declaration order; `enc` and `payload` each carry a u16 length prefix.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive (Clone, Debug)]
pub struct EncryptedClientHelloOuter {
    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
    /// ECHConfigContents.cipher_suites list.
    pub cipher_suite: HpkeSymmetricCipherSuite,
    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
    pub config_id: u8,
    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
    pub enc: PayloadU16,
    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
    pub payload: PayloadU16,
}
3093 | |
3094 | impl Codec<'_> for EncryptedClientHelloOuter { |
3095 | fn encode(&self, bytes: &mut Vec<u8>) { |
3096 | self.cipher_suite.encode(bytes); |
3097 | self.config_id.encode(bytes); |
3098 | self.enc.encode(bytes); |
3099 | self.payload.encode(bytes); |
3100 | } |
3101 | |
3102 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
3103 | Ok(Self { |
3104 | cipher_suite: HpkeSymmetricCipherSuite::read(r)?, |
3105 | config_id: u8::read(r)?, |
3106 | enc: PayloadU16::read(r)?, |
3107 | payload: PayloadU16::read(r)?, |
3108 | }) |
3109 | } |
3110 | } |
3111 | |
/// Representation of the ECHEncryptedExtensions extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive (Clone, Debug)]
pub struct ServerEncryptedClientHello {
    // ECH configs sent by the server, encoded as a u16-length-prefixed list. Per the
    // field name these are for the client to retry with — confirm at use sites.
    pub(crate) retry_configs: Vec<EchConfigPayload>,
}
3120 | |
3121 | impl Codec<'_> for ServerEncryptedClientHello { |
3122 | fn encode(&self, bytes: &mut Vec<u8>) { |
3123 | self.retry_configs.encode(bytes); |
3124 | } |
3125 | |
3126 | fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { |
3127 | Ok(Self { |
3128 | retry_configs: Vec::<EchConfigPayload>::read(r)?, |
3129 | }) |
3130 | } |
3131 | } |
3132 | |
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    ///
    /// NOTE(review): `to_compress` lists extension types that are presumably replaced by an
    /// outer-extensions reference in the inner hello — confirm at the encoding site.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3146 | |
/// Returns `true` if `iter` yields two items that compare equal after conversion into `T`.
///
/// Items are converted with `Into<T>` and tracked in an ordered set; iteration stops at the
/// first repeat. Runs in O(n log n) for n items.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    // The element type is pinned to `T` explicitly instead of being left to inference
    // through the `E: Into<T>` bound.
    let mut seen: BTreeSet<T> = BTreeSet::new();
    // `insert` returns false when the value was already present, which stops `all`.
    !iter.into_iter().all(|item| seen.insert(item.into()))
}
3158 | |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ech_config_dupe_exts() {
        let ext = unknown_ext(0x42);
        let mut config = config_template();
        config.extensions = vec![ext.clone(), ext];

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        // Setting the high bit of the type code marks the extension as mandatory.
        let mut config = config_template();
        config.extensions = vec![unknown_ext(0x42 | 0x8000)];

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// An ECH config extension of unrecognised type `code` with a one-byte payload.
    fn unknown_ext(code: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(code),
            payload: Payload::new(vec![0x42]),
        })
    }

    /// A minimal, valid ECH config with no extensions.
    fn config_template() -> EchConfigContents {
        let cipher_suite = HpkeSymmetricCipherSuite {
            kdf_id: HpkeKdf::HKDF_SHA256,
            aead_id: HpkeAead::AES_128_GCM,
        };
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16(b"xxx".into()),
            symmetric_cipher_suites: vec![cipher_suite],
        };
        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}
3211 | |