1extern crate openssl;
2extern crate openssl_probe;
3
4use self::openssl::error::ErrorStack;
5use self::openssl::hash::MessageDigest;
6use self::openssl::nid::Nid;
7use self::openssl::pkcs12::Pkcs12;
8use self::openssl::pkey::{PKey, Private};
9use self::openssl::ssl::{
10 self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11 SslVerifyMode,
12};
13use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
14use std::error;
15use std::fmt;
16use std::io;
17use std::sync::Once;
18
19use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
20
/// Restricts the protocol versions `ctx` may negotiate to the range `[min, max]`.
///
/// Each bound is translated to an OpenSSL `SslVersion` and handed to
/// `set_min_proto_version` / `set_max_proto_version`; a `None` bound is passed through
/// unchanged. This variant is compiled when the `have_min_max_version` cfg is set
/// (presumably by the build script when the linked OpenSSL supports those setters).
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    /// Maps this crate's protocol enum onto OpenSSL's version constants.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Sslv3 => SslVersion::SSL3,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv12 => SslVersion::TLS1_2,
            Protocol::__NonExhaustive => unreachable!(),
        }
    }

    // The lower bound is applied first; the upper bound only if that succeeded.
    ctx.set_min_proto_version(min.map(to_ssl_version))
        .and_then(|()| ctx.set_max_proto_version(max.map(to_ssl_version)))
}
44
45#[cfg(not(have_min_max_version))]
46fn supported_protocols(
47 min: Option<Protocol>,
48 max: Option<Protocol>,
49 ctx: &mut SslContextBuilder,
50) -> Result<(), ErrorStack> {
51 use self::openssl::ssl::SslOptions;
52
53 let no_ssl_mask = SslOptions::NO_SSLV2
54 | SslOptions::NO_SSLV3
55 | SslOptions::NO_TLSV1
56 | SslOptions::NO_TLSV1_1
57 | SslOptions::NO_TLSV1_2;
58
59 ctx.clear_options(no_ssl_mask);
60 let mut options = SslOptions::empty();
61 options |= match min {
62 None => SslOptions::empty(),
63 Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
64 Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
65 Some(Protocol::Tlsv11) => {
66 SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
67 }
68 Some(Protocol::Tlsv12) => {
69 SslOptions::NO_SSLV2
70 | SslOptions::NO_SSLV3
71 | SslOptions::NO_TLSV1
72 | SslOptions::NO_TLSV1_1
73 }
74 Some(Protocol::__NonExhaustive) => unreachable!(),
75 };
76 options |= match max {
77 None | Some(Protocol::Tlsv12) => SslOptions::empty(),
78 Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
79 Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
80 Some(Protocol::Sslv3) => {
81 SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
82 }
83 Some(Protocol::__NonExhaustive) => unreachable!(),
84 };
85
86 ctx.set_options(options);
87
88 Ok(())
89}
90
91fn init_trust() {
92 static ONCE: Once = Once::new();
93 ONCE.call_once(openssl_probe::init_ssl_cert_env_vars);
94}
95
/// Adds every PEM certificate found in Android's `/system/etc/security/cacerts`
/// directory to `connector`'s certificate store.
///
/// This is best effort: a missing directory, an unreadable entry, a file that is not
/// valid PEM, or a store insertion failure is skipped (insertion failures are logged at
/// debug level). As written it always returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        Err(_) => return Ok(()),
    };

    for entry in entries {
        let parsed = entry
            .ok()
            .and_then(|e| fs::read(e.path()).ok())
            .and_then(|bytes| X509::from_pem(&bytes).ok());

        if let Some(cert) = parsed {
            if let Err(err) = connector.cert_store_mut().add_cert(cert) {
                debug!("load_android_root_certs error: {:?}", err);
            }
        }
    }

    Ok(())
}
114
115#[derive(Debug)]
116pub enum Error {
117 Normal(ErrorStack),
118 Ssl(ssl::Error, X509VerifyResult),
119 EmptyChain,
120 NotPkcs8,
121}
122
123impl error::Error for Error {
124 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
125 match *self {
126 Error::Normal(ref e: &ErrorStack) => error::Error::source(self:e),
127 Error::Ssl(ref e: &Error, _) => error::Error::source(self:e),
128 Error::EmptyChain => None,
129 Error::NotPkcs8 => None,
130 }
131 }
132}
133
134impl fmt::Display for Error {
135 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
136 match *self {
137 Error::Normal(ref e: &ErrorStack) => fmt::Display::fmt(self:e, f:fmt),
138 Error::Ssl(ref e: &Error, X509VerifyResult::OK) => fmt::Display::fmt(self:e, f:fmt),
139 Error::Ssl(ref e: &Error, v: X509VerifyResult) => write!(fmt, "{} ({})", e, v),
140 Error::EmptyChain => write!(
141 fmt,
142 "at least one certificate must be provided to create an identity"
143 ),
144 Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
145 }
146 }
147}
148
149impl From<ErrorStack> for Error {
150 fn from(err: ErrorStack) -> Error {
151 Error::Normal(err)
152 }
153}
154
155#[derive(Clone)]
156pub struct Identity {
157 pkey: PKey<Private>,
158 cert: X509,
159 chain: Vec<X509>,
160}
161
162impl Identity {
163 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
164 let pkcs12 = Pkcs12::from_der(buf)?;
165 let parsed = pkcs12.parse(pass)?;
166 Ok(Identity {
167 pkey: parsed.pkey,
168 cert: parsed.cert,
169 // > The stack is the reverse of what you might expect due to the way
170 // > PKCS12_parse is implemented, so we need to load it backwards.
171 // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44
172 chain: parsed.chain.into_iter().flatten().rev().collect(),
173 })
174 }
175
176 pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
177 if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
178 return Err(Error::NotPkcs8);
179 }
180
181 let pkey = PKey::private_key_from_pem(key)?;
182 let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
183 let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
184 let chain = cert_chain.collect();
185 Ok(Identity { pkey, cert, chain })
186 }
187}
188
189#[derive(Clone)]
190pub struct Certificate(X509);
191
192impl Certificate {
193 pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
194 let cert: X509 = X509::from_der(buf)?;
195 Ok(Certificate(cert))
196 }
197
198 pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
199 let cert: X509 = X509::from_pem(buf)?;
200 Ok(Certificate(cert))
201 }
202
203 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
204 let der: Vec = self.0.to_der()?;
205 Ok(der)
206 }
207}
208
209pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
210
211impl<S> fmt::Debug for MidHandshakeTlsStream<S>
212where
213 S: fmt::Debug,
214{
215 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
216 fmt::Debug::fmt(&self.0, f:fmt)
217 }
218}
219
220impl<S> MidHandshakeTlsStream<S> {
221 pub fn get_ref(&self) -> &S {
222 self.0.get_ref()
223 }
224
225 pub fn get_mut(&mut self) -> &mut S {
226 self.0.get_mut()
227 }
228}
229
230impl<S> MidHandshakeTlsStream<S>
231where
232 S: io::Read + io::Write,
233{
234 pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
235 match self.0.handshake() {
236 Ok(s: SslStream) => Ok(TlsStream(s)),
237 Err(e: HandshakeError) => Err(e.into()),
238 }
239 }
240}
241
242pub enum HandshakeError<S> {
243 Failure(Error),
244 WouldBlock(MidHandshakeTlsStream<S>),
245}
246
247impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
248 fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
249 match e {
250 ssl::HandshakeError::SetupFailure(e: ErrorStack) => HandshakeError::Failure(e.into()),
251 ssl::HandshakeError::Failure(e: MidHandshakeSslStream) => {
252 let v: X509VerifyResult = e.ssl().verify_result();
253 HandshakeError::Failure(Error::Ssl(e.into_error(), v))
254 }
255 ssl::HandshakeError::WouldBlock(s: MidHandshakeSslStream) => {
256 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
257 }
258 }
259 }
260}
261
262impl<S> From<ErrorStack> for HandshakeError<S> {
263 fn from(e: ErrorStack) -> HandshakeError<S> {
264 HandshakeError::Failure(e.into())
265 }
266}
267
268#[derive(Clone)]
269pub struct TlsConnector {
270 connector: SslConnector,
271 use_sni: bool,
272 accept_invalid_hostnames: bool,
273 accept_invalid_certs: bool,
274}
275
276impl TlsConnector {
277 pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
278 init_trust();
279
280 let mut connector = SslConnector::builder(SslMethod::tls())?;
281 if let Some(ref identity) = builder.identity {
282 connector.set_certificate(&identity.0.cert)?;
283 connector.set_private_key(&identity.0.pkey)?;
284 for cert in identity.0.chain.iter() {
285 // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
286 // specifies that "When sending a certificate chain, extra chain certificates are
287 // sent in order following the end entity certificate."
288 connector.add_extra_chain_cert(cert.to_owned())?;
289 }
290 }
291 supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
292
293 if builder.disable_built_in_roots {
294 connector.set_cert_store(X509StoreBuilder::new()?.build());
295 }
296
297 for cert in &builder.root_certificates {
298 if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
299 debug!("add_cert error: {:?}", err);
300 }
301 }
302
303 #[cfg(feature = "alpn")]
304 {
305 if !builder.alpn.is_empty() {
306 // Wire format is each alpn preceded by its length as a byte.
307 let mut alpn_wire_format = Vec::with_capacity(
308 builder
309 .alpn
310 .iter()
311 .map(|s| s.as_bytes().len())
312 .sum::<usize>()
313 + builder.alpn.len(),
314 );
315 for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
316 alpn_wire_format.push(alpn.len() as u8);
317 alpn_wire_format.extend(alpn);
318 }
319 connector.set_alpn_protos(&alpn_wire_format)?;
320 }
321 }
322
323 #[cfg(target_os = "android")]
324 load_android_root_certs(&mut connector)?;
325
326 Ok(TlsConnector {
327 connector: connector.build(),
328 use_sni: builder.use_sni,
329 accept_invalid_hostnames: builder.accept_invalid_hostnames,
330 accept_invalid_certs: builder.accept_invalid_certs,
331 })
332 }
333
334 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
335 where
336 S: io::Read + io::Write,
337 {
338 let mut ssl = self
339 .connector
340 .configure()?
341 .use_server_name_indication(self.use_sni)
342 .verify_hostname(!self.accept_invalid_hostnames);
343 if self.accept_invalid_certs {
344 ssl.set_verify(SslVerifyMode::NONE);
345 }
346
347 let s = ssl.connect(domain, stream)?;
348 Ok(TlsStream(s))
349 }
350}
351
352impl fmt::Debug for TlsConnector {
353 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
354 fmt&mut DebugStruct<'_, '_>.debug_struct("TlsConnector")
355 // n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
356 .field("use_sni", &self.use_sni)
357 .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
358 .field(name:"accept_invalid_certs", &self.accept_invalid_certs)
359 .finish()
360 }
361}
362
363#[derive(Clone)]
364pub struct TlsAcceptor(SslAcceptor);
365
366impl TlsAcceptor {
367 pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
368 let mut acceptor: SslAcceptorBuilder = SslAcceptor::mozilla_intermediate(method:SslMethod::tls())?;
369 acceptor.set_private_key(&builder.identity.0.pkey)?;
370 acceptor.set_certificate(&builder.identity.0.cert)?;
371 for cert: &X509 in builder.identity.0.chain.iter() {
372 // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
373 // specifies that "When sending a certificate chain, extra chain certificates are
374 // sent in order following the end entity certificate."
375 acceptor.add_extra_chain_cert(cert.to_owned())?;
376 }
377 supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
378
379 Ok(TlsAcceptor(acceptor.build()))
380 }
381
382 pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
383 where
384 S: io::Read + io::Write,
385 {
386 let s: SslStream = self.0.accept(stream)?;
387 Ok(TlsStream(s))
388 }
389}
390
391pub struct TlsStream<S>(ssl::SslStream<S>);
392
393impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
394 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
395 fmt::Debug::fmt(&self.0, f:fmt)
396 }
397}
398
399impl<S> TlsStream<S> {
400 pub fn get_ref(&self) -> &S {
401 self.0.get_ref()
402 }
403
404 pub fn get_mut(&mut self) -> &mut S {
405 self.0.get_mut()
406 }
407}
408
409impl<S: io::Read + io::Write> TlsStream<S> {
410 pub fn buffered_read_size(&self) -> Result<usize, Error> {
411 Ok(self.0.ssl().pending())
412 }
413
414 pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
415 Ok(self.0.ssl().peer_certificate().map(Certificate))
416 }
417
418 #[cfg(feature = "alpn")]
419 pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
420 Ok(self
421 .0
422 .ssl()
423 .selected_alpn_protocol()
424 .map(|alpn| alpn.to_vec()))
425 }
426
427 pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
428 let cert = if self.0.ssl().is_server() {
429 self.0.ssl().certificate().map(|x| x.to_owned())
430 } else {
431 self.0.ssl().peer_certificate()
432 };
433
434 let cert = match cert {
435 Some(cert) => cert,
436 None => return Ok(None),
437 };
438
439 let algo_nid = cert.signature_algorithm().object().nid();
440 let signature_algorithms = match algo_nid.signature_algorithms() {
441 Some(algs) => algs,
442 None => return Ok(None),
443 };
444
445 let md = match signature_algorithms.digest {
446 Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
447 nid => match MessageDigest::from_nid(nid) {
448 Some(md) => md,
449 None => return Ok(None),
450 },
451 };
452
453 let digest = cert.digest(md)?;
454
455 Ok(Some(digest.to_vec()))
456 }
457
458 pub fn shutdown(&mut self) -> io::Result<()> {
459 match self.0.shutdown() {
460 Ok(_) => Ok(()),
461 Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
462 Err(e) => Err(e
463 .into_io_error()
464 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
465 }
466 }
467}
468
469impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
470 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
471 self.0.read(buf)
472 }
473}
474
475impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
476 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
477 self.0.write(buf)
478 }
479
480 fn flush(&mut self) -> io::Result<()> {
481 self.0.flush()
482 }
483}
484