1 | extern crate openssl; |
2 | extern crate openssl_probe; |
3 | |
4 | use self::openssl::error::ErrorStack; |
5 | use self::openssl::hash::MessageDigest; |
6 | use self::openssl::nid::Nid; |
7 | use self::openssl::pkcs12::Pkcs12; |
8 | use self::openssl::pkey::{PKey, Private}; |
9 | use self::openssl::ssl::{ |
10 | self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod, |
11 | SslVerifyMode, |
12 | }; |
13 | use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509}; |
14 | use std::error; |
15 | use std::fmt; |
16 | use std::io; |
17 | use std::sync::Once; |
18 | |
19 | use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; |
20 | |
/// Restricts `ctx` to the protocol range `[min, max]` using OpenSSL's native
/// min/max protocol-version setters.
///
/// A `None` bound is forwarded to OpenSSL as `None`, leaving that side of the
/// range at OpenSSL's default.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    /// Maps a crate-level `Protocol` onto the corresponding OpenSSL version constant.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Sslv3 => SslVersion::SSL3,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv12 => SslVersion::TLS1_2,
            Protocol::__NonExhaustive => unreachable!(),
        }
    }

    ctx.set_min_proto_version(min.map(to_ssl_version))?;
    ctx.set_max_proto_version(max.map(to_ssl_version))
}
44 | |
45 | #[cfg (not(have_min_max_version))] |
46 | fn supported_protocols( |
47 | min: Option<Protocol>, |
48 | max: Option<Protocol>, |
49 | ctx: &mut SslContextBuilder, |
50 | ) -> Result<(), ErrorStack> { |
51 | use self::openssl::ssl::SslOptions; |
52 | |
53 | let no_ssl_mask = SslOptions::NO_SSLV2 |
54 | | SslOptions::NO_SSLV3 |
55 | | SslOptions::NO_TLSV1 |
56 | | SslOptions::NO_TLSV1_1 |
57 | | SslOptions::NO_TLSV1_2; |
58 | |
59 | ctx.clear_options(no_ssl_mask); |
60 | let mut options = SslOptions::empty(); |
61 | options |= match min { |
62 | None => SslOptions::empty(), |
63 | Some(Protocol::Sslv3) => SslOptions::NO_SSLV2, |
64 | Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3, |
65 | Some(Protocol::Tlsv11) => { |
66 | SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 |
67 | } |
68 | Some(Protocol::Tlsv12) => { |
69 | SslOptions::NO_SSLV2 |
70 | | SslOptions::NO_SSLV3 |
71 | | SslOptions::NO_TLSV1 |
72 | | SslOptions::NO_TLSV1_1 |
73 | } |
74 | Some(Protocol::__NonExhaustive) => unreachable!(), |
75 | }; |
76 | options |= match max { |
77 | None | Some(Protocol::Tlsv12) => SslOptions::empty(), |
78 | Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2, |
79 | Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2, |
80 | Some(Protocol::Sslv3) => { |
81 | SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2 |
82 | } |
83 | Some(Protocol::__NonExhaustive) => unreachable!(), |
84 | }; |
85 | |
86 | ctx.set_options(options); |
87 | |
88 | Ok(()) |
89 | } |
90 | |
91 | fn init_trust() { |
92 | static ONCE: Once = Once::new(); |
93 | ONCE.call_once(openssl_probe::init_ssl_cert_env_vars); |
94 | } |
95 | |
/// Adds every PEM certificate found in Android's system CA directory
/// (`/system/etc/security/cacerts`) to `connector`'s certificate store.
///
/// This is best effort. A missing directory, unreadable entries and
/// non-PEM files are skipped silently. Certificates the store rejects are
/// logged at debug level. The function always returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        Err(_) => return Ok(()),
    };

    for entry in entries {
        let entry = match entry {
            Ok(entry) => entry,
            Err(_) => continue,
        };
        let pem = match fs::read(entry.path()) {
            Ok(bytes) => bytes,
            Err(_) => continue,
        };
        let cert = match X509::from_pem(&pem) {
            Ok(cert) => cert,
            Err(_) => continue,
        };
        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
114 | |
/// Errors produced by the OpenSSL backend.
#[derive(Debug)]
pub enum Error {
    /// An `ErrorStack` returned by an OpenSSL call (produced via `From<ErrorStack>`).
    Normal(ErrorStack),
    /// A failed TLS handshake, paired with the certificate verification result
    /// read from the connection at the moment of failure.
    Ssl(ssl::Error, X509VerifyResult),
    /// `Identity::from_pkcs8` received certificate PEM data containing no certificates.
    EmptyChain,
    /// `Identity::from_pkcs8` received a key that does not begin with the
    /// PKCS#8 `-----BEGIN PRIVATE KEY-----` header.
    NotPkcs8,
}
122 | |
123 | impl error::Error for Error { |
124 | fn source(&self) -> Option<&(dyn error::Error + 'static)> { |
125 | match *self { |
126 | Error::Normal(ref e: &ErrorStack) => error::Error::source(self:e), |
127 | Error::Ssl(ref e: &Error, _) => error::Error::source(self:e), |
128 | Error::EmptyChain => None, |
129 | Error::NotPkcs8 => None, |
130 | } |
131 | } |
132 | } |
133 | |
134 | impl fmt::Display for Error { |
135 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
136 | match *self { |
137 | Error::Normal(ref e: &ErrorStack) => fmt::Display::fmt(self:e, f:fmt), |
138 | Error::Ssl(ref e: &Error, X509VerifyResult::OK) => fmt::Display::fmt(self:e, f:fmt), |
139 | Error::Ssl(ref e: &Error, v: X509VerifyResult) => write!(fmt, " {} ( {})" , e, v), |
140 | Error::EmptyChain => write!( |
141 | fmt, |
142 | "at least one certificate must be provided to create an identity" |
143 | ), |
144 | Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM" ), |
145 | } |
146 | } |
147 | } |
148 | |
149 | impl From<ErrorStack> for Error { |
150 | fn from(err: ErrorStack) -> Error { |
151 | Error::Normal(err) |
152 | } |
153 | } |
154 | |
/// A private key with its end-entity certificate and any additional chain
/// certificates, loaded from a PKCS#12 archive or from PKCS#8/PEM input.
#[derive(Clone)]
pub struct Identity {
    /// The identity's private key.
    pkey: PKey<Private>,
    /// The end-entity (leaf) certificate.
    cert: X509,
    /// Extra chain certificates, in the order they are registered with
    /// `add_extra_chain_cert` (and therefore sent) after `cert`.
    chain: Vec<X509>,
}
161 | |
162 | impl Identity { |
163 | pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> { |
164 | let pkcs12 = Pkcs12::from_der(buf)?; |
165 | let parsed = pkcs12.parse(pass)?; |
166 | Ok(Identity { |
167 | pkey: parsed.pkey, |
168 | cert: parsed.cert, |
169 | // > The stack is the reverse of what you might expect due to the way |
170 | // > PKCS12_parse is implemented, so we need to load it backwards. |
171 | // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44 |
172 | chain: parsed.chain.into_iter().flatten().rev().collect(), |
173 | }) |
174 | } |
175 | |
176 | pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> { |
177 | if !key.starts_with(b"-----BEGIN PRIVATE KEY-----" ) { |
178 | return Err(Error::NotPkcs8); |
179 | } |
180 | |
181 | let pkey = PKey::private_key_from_pem(key)?; |
182 | let mut cert_chain = X509::stack_from_pem(buf)?.into_iter(); |
183 | let cert = cert_chain.next().ok_or(Error::EmptyChain)?; |
184 | let chain = cert_chain.collect(); |
185 | Ok(Identity { pkey, cert, chain }) |
186 | } |
187 | } |
188 | |
189 | #[derive (Clone)] |
190 | pub struct Certificate(X509); |
191 | |
192 | impl Certificate { |
193 | pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> { |
194 | let cert: X509 = X509::from_der(buf)?; |
195 | Ok(Certificate(cert)) |
196 | } |
197 | |
198 | pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> { |
199 | let cert: X509 = X509::from_pem(buf)?; |
200 | Ok(Certificate(cert)) |
201 | } |
202 | |
203 | pub fn to_der(&self) -> Result<Vec<u8>, Error> { |
204 | let der: Vec = self.0.to_der()?; |
205 | Ok(der) |
206 | } |
207 | } |
208 | |
209 | pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>); |
210 | |
211 | impl<S> fmt::Debug for MidHandshakeTlsStream<S> |
212 | where |
213 | S: fmt::Debug, |
214 | { |
215 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
216 | fmt::Debug::fmt(&self.0, f:fmt) |
217 | } |
218 | } |
219 | |
220 | impl<S> MidHandshakeTlsStream<S> { |
221 | pub fn get_ref(&self) -> &S { |
222 | self.0.get_ref() |
223 | } |
224 | |
225 | pub fn get_mut(&mut self) -> &mut S { |
226 | self.0.get_mut() |
227 | } |
228 | } |
229 | |
230 | impl<S> MidHandshakeTlsStream<S> |
231 | where |
232 | S: io::Read + io::Write, |
233 | { |
234 | pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> { |
235 | match self.0.handshake() { |
236 | Ok(s: SslStream) => Ok(TlsStream(s)), |
237 | Err(e: HandshakeError) => Err(e.into()), |
238 | } |
239 | } |
240 | } |
241 | |
242 | pub enum HandshakeError<S> { |
243 | Failure(Error), |
244 | WouldBlock(MidHandshakeTlsStream<S>), |
245 | } |
246 | |
247 | impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> { |
248 | fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> { |
249 | match e { |
250 | ssl::HandshakeError::SetupFailure(e: ErrorStack) => HandshakeError::Failure(e.into()), |
251 | ssl::HandshakeError::Failure(e: MidHandshakeSslStream) => { |
252 | let v: X509VerifyResult = e.ssl().verify_result(); |
253 | HandshakeError::Failure(Error::Ssl(e.into_error(), v)) |
254 | } |
255 | ssl::HandshakeError::WouldBlock(s: MidHandshakeSslStream) => { |
256 | HandshakeError::WouldBlock(MidHandshakeTlsStream(s)) |
257 | } |
258 | } |
259 | } |
260 | } |
261 | |
262 | impl<S> From<ErrorStack> for HandshakeError<S> { |
263 | fn from(e: ErrorStack) -> HandshakeError<S> { |
264 | HandshakeError::Failure(e.into()) |
265 | } |
266 | } |
267 | |
/// Client-side TLS configuration, built once by `TlsConnector::new` and
/// reused for every `connect` call.
#[derive(Clone)]
pub struct TlsConnector {
    /// The configured OpenSSL connector.
    connector: SslConnector,
    /// Whether to send the Server Name Indication extension on connect.
    use_sni: bool,
    /// When `true`, hostname verification is turned off.
    accept_invalid_hostnames: bool,
    /// When `true`, peer certificate verification is turned off (`SslVerifyMode::NONE`).
    accept_invalid_certs: bool,
}
275 | |
276 | impl TlsConnector { |
277 | pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> { |
278 | init_trust(); |
279 | |
280 | let mut connector = SslConnector::builder(SslMethod::tls())?; |
281 | if let Some(ref identity) = builder.identity { |
282 | connector.set_certificate(&identity.0.cert)?; |
283 | connector.set_private_key(&identity.0.pkey)?; |
284 | for cert in identity.0.chain.iter() { |
285 | // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html |
286 | // specifies that "When sending a certificate chain, extra chain certificates are |
287 | // sent in order following the end entity certificate." |
288 | connector.add_extra_chain_cert(cert.to_owned())?; |
289 | } |
290 | } |
291 | supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?; |
292 | |
293 | if builder.disable_built_in_roots { |
294 | connector.set_cert_store(X509StoreBuilder::new()?.build()); |
295 | } |
296 | |
297 | for cert in &builder.root_certificates { |
298 | if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) { |
299 | debug!("add_cert error: {:?}" , err); |
300 | } |
301 | } |
302 | |
303 | #[cfg (feature = "alpn" )] |
304 | { |
305 | if !builder.alpn.is_empty() { |
306 | // Wire format is each alpn preceded by its length as a byte. |
307 | let mut alpn_wire_format = Vec::with_capacity( |
308 | builder |
309 | .alpn |
310 | .iter() |
311 | .map(|s| s.as_bytes().len()) |
312 | .sum::<usize>() |
313 | + builder.alpn.len(), |
314 | ); |
315 | for alpn in builder.alpn.iter().map(|s| s.as_bytes()) { |
316 | alpn_wire_format.push(alpn.len() as u8); |
317 | alpn_wire_format.extend(alpn); |
318 | } |
319 | connector.set_alpn_protos(&alpn_wire_format)?; |
320 | } |
321 | } |
322 | |
323 | #[cfg (target_os = "android" )] |
324 | load_android_root_certs(&mut connector)?; |
325 | |
326 | Ok(TlsConnector { |
327 | connector: connector.build(), |
328 | use_sni: builder.use_sni, |
329 | accept_invalid_hostnames: builder.accept_invalid_hostnames, |
330 | accept_invalid_certs: builder.accept_invalid_certs, |
331 | }) |
332 | } |
333 | |
334 | pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> |
335 | where |
336 | S: io::Read + io::Write, |
337 | { |
338 | let mut ssl = self |
339 | .connector |
340 | .configure()? |
341 | .use_server_name_indication(self.use_sni) |
342 | .verify_hostname(!self.accept_invalid_hostnames); |
343 | if self.accept_invalid_certs { |
344 | ssl.set_verify(SslVerifyMode::NONE); |
345 | } |
346 | |
347 | let s = ssl.connect(domain, stream)?; |
348 | Ok(TlsStream(s)) |
349 | } |
350 | } |
351 | |
352 | impl fmt::Debug for TlsConnector { |
353 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
354 | fmt&mut DebugStruct<'_, '_>.debug_struct("TlsConnector" ) |
355 | // n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted |
356 | .field("use_sni" , &self.use_sni) |
357 | .field("accept_invalid_hostnames" , &self.accept_invalid_hostnames) |
358 | .field(name:"accept_invalid_certs" , &self.accept_invalid_certs) |
359 | .finish() |
360 | } |
361 | } |
362 | |
363 | #[derive (Clone)] |
364 | pub struct TlsAcceptor(SslAcceptor); |
365 | |
366 | impl TlsAcceptor { |
367 | pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> { |
368 | let mut acceptor: SslAcceptorBuilder = SslAcceptor::mozilla_intermediate(method:SslMethod::tls())?; |
369 | acceptor.set_private_key(&builder.identity.0.pkey)?; |
370 | acceptor.set_certificate(&builder.identity.0.cert)?; |
371 | for cert: &X509 in builder.identity.0.chain.iter() { |
372 | // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html |
373 | // specifies that "When sending a certificate chain, extra chain certificates are |
374 | // sent in order following the end entity certificate." |
375 | acceptor.add_extra_chain_cert(cert.to_owned())?; |
376 | } |
377 | supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?; |
378 | |
379 | Ok(TlsAcceptor(acceptor.build())) |
380 | } |
381 | |
382 | pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> |
383 | where |
384 | S: io::Read + io::Write, |
385 | { |
386 | let s: SslStream = self.0.accept(stream)?; |
387 | Ok(TlsStream(s)) |
388 | } |
389 | } |
390 | |
391 | pub struct TlsStream<S>(ssl::SslStream<S>); |
392 | |
393 | impl<S: fmt::Debug> fmt::Debug for TlsStream<S> { |
394 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
395 | fmt::Debug::fmt(&self.0, f:fmt) |
396 | } |
397 | } |
398 | |
399 | impl<S> TlsStream<S> { |
400 | pub fn get_ref(&self) -> &S { |
401 | self.0.get_ref() |
402 | } |
403 | |
404 | pub fn get_mut(&mut self) -> &mut S { |
405 | self.0.get_mut() |
406 | } |
407 | } |
408 | |
409 | impl<S: io::Read + io::Write> TlsStream<S> { |
410 | pub fn buffered_read_size(&self) -> Result<usize, Error> { |
411 | Ok(self.0.ssl().pending()) |
412 | } |
413 | |
414 | pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> { |
415 | Ok(self.0.ssl().peer_certificate().map(Certificate)) |
416 | } |
417 | |
418 | #[cfg (feature = "alpn" )] |
419 | pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> { |
420 | Ok(self |
421 | .0 |
422 | .ssl() |
423 | .selected_alpn_protocol() |
424 | .map(|alpn| alpn.to_vec())) |
425 | } |
426 | |
427 | pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> { |
428 | let cert = if self.0.ssl().is_server() { |
429 | self.0.ssl().certificate().map(|x| x.to_owned()) |
430 | } else { |
431 | self.0.ssl().peer_certificate() |
432 | }; |
433 | |
434 | let cert = match cert { |
435 | Some(cert) => cert, |
436 | None => return Ok(None), |
437 | }; |
438 | |
439 | let algo_nid = cert.signature_algorithm().object().nid(); |
440 | let signature_algorithms = match algo_nid.signature_algorithms() { |
441 | Some(algs) => algs, |
442 | None => return Ok(None), |
443 | }; |
444 | |
445 | let md = match signature_algorithms.digest { |
446 | Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(), |
447 | nid => match MessageDigest::from_nid(nid) { |
448 | Some(md) => md, |
449 | None => return Ok(None), |
450 | }, |
451 | }; |
452 | |
453 | let digest = cert.digest(md)?; |
454 | |
455 | Ok(Some(digest.to_vec())) |
456 | } |
457 | |
458 | pub fn shutdown(&mut self) -> io::Result<()> { |
459 | match self.0.shutdown() { |
460 | Ok(_) => Ok(()), |
461 | Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()), |
462 | Err(e) => Err(e |
463 | .into_io_error() |
464 | .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))), |
465 | } |
466 | } |
467 | } |
468 | |
469 | impl<S: io::Read + io::Write> io::Read for TlsStream<S> { |
470 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
471 | self.0.read(buf) |
472 | } |
473 | } |
474 | |
475 | impl<S: io::Read + io::Write> io::Write for TlsStream<S> { |
476 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
477 | self.0.write(buf) |
478 | } |
479 | |
480 | fn flush(&mut self) -> io::Result<()> { |
481 | self.0.flush() |
482 | } |
483 | } |
484 | |