| 1 | #![doc (html_root_url = "https://docs.rs/tokio-native-tls/0.3.0" )] |
| 2 | #![warn ( |
| 3 | missing_debug_implementations, |
| 4 | missing_docs, |
| 5 | rust_2018_idioms, |
| 6 | unreachable_pub |
| 7 | )] |
| 8 | #![deny (rustdoc::broken_intra_doc_links)] |
| 9 | #![doc (test( |
| 10 | no_crate_inject, |
| 11 | attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) |
| 12 | ))] |
| 13 | |
| 14 | //! Async TLS streams |
| 15 | //! |
| 16 | //! This library is an implementation of TLS streams using the most appropriate |
| 17 | //! system library by default for negotiating the connection. That is, on |
| 18 | //! Windows this library uses SChannel, on OSX it uses SecureTransport, and on |
| 19 | //! other platforms it uses OpenSSL. |
| 20 | //! |
//! Each TLS stream implements Tokio's `AsyncRead` and `AsyncWrite` traits so
//! that it can interoperate with the rest of the Tokio I/O ecosystem. Client
//! connections initiated from this crate verify hostnames automatically and by
//! default.
| 24 | //! |
| 25 | //! This crate primarily exports this ability through two newtypes, |
| 26 | //! `TlsConnector` and `TlsAcceptor`. These newtypes augment the |
| 27 | //! functionality provided by the `native-tls` crate, on which this crate is |
| 28 | //! built. Configuration of TLS parameters is still primarily done through the |
| 29 | //! `native-tls` crate. |
| 30 | |
| 31 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| 32 | |
| 33 | use crate::native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; |
| 34 | use std::fmt; |
| 35 | use std::future::Future; |
| 36 | use std::io::{self, Read, Write}; |
| 37 | use std::marker::Unpin; |
| 38 | #[cfg (unix)] |
| 39 | use std::os::unix::io::{AsRawFd, RawFd}; |
| 40 | #[cfg (windows)] |
| 41 | use std::os::windows::io::{AsRawSocket, RawSocket}; |
| 42 | use std::pin::Pin; |
| 43 | use std::ptr::null_mut; |
| 44 | use std::task::{Context, Poll}; |
| 45 | |
/// An intermediate wrapper for the inner stream `S`.
///
/// It adapts an async stream to the blocking `std::io::Read`/`Write` traits
/// that `native_tls` drives. While a poll is in progress, `context` holds the
/// caller's task context; at all other times it is null.
#[derive(Debug)]
pub struct AllowStd<S> {
    /// The wrapped async stream.
    inner: S,
    /// Type-erased `*mut Context<'_>` of the poll currently running, or null.
    context: *mut (),
}

impl<S> AllowStd<S> {
    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }
}
| 64 | |
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
///
/// Values of this type are produced by `TlsConnector::connect` and
/// `TlsAcceptor::accept`. When `S` is `AsyncRead + AsyncWrite + Unpin`, the
/// stream itself implements Tokio's `AsyncRead` and `AsyncWrite`.
#[derive(Debug)]
pub struct TlsStream<S>(native_tls::TlsStream<AllowStd<S>>);
| 74 | |
/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect`
/// method.
///
/// Configure a `native_tls::TlsConnector` as usual, then convert it with
/// `TlsConnector::from` (or `.into()`). Cloning clones the wrapped connector.
#[derive(Clone)]
pub struct TlsConnector(native_tls::TlsConnector);
| 79 | |
/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept`
/// method.
///
/// Configure a `native_tls::TlsAcceptor` as usual, then convert it with
/// `TlsAcceptor::from` (or `.into()`). Cloning clones the wrapped acceptor.
#[derive(Clone)]
pub struct TlsAcceptor(native_tls::TlsAcceptor);
| 84 | |
/// Future that finishes a handshake `native_tls` could not complete in one go.
///
/// Holds `Some` while the handshake is still in progress. It becomes `None`
/// once the handshake succeeds or fails, and polling it after that panics.
struct MidHandshake<S>(Option<MidHandshakeTlsStream<AllowStd<S>>>);
| 86 | |
/// Outcome of the first attempt at a handshake, as reported by
/// `StartedHandshakeFuture`.
enum StartedHandshake<S> {
    /// The handshake completed on the first attempt.
    Done(TlsStream<S>),
    /// The handshake needs more I/O; it is continued by `MidHandshake`.
    Mid(MidHandshakeTlsStream<AllowStd<S>>),
}
| 91 | |
/// Future that starts a handshake by calling the stored closure once with the
/// stream wrapped in `AllowStd`. It always resolves on its first poll; polling
/// it again afterwards panics.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// State consumed by the first poll of `StartedHandshakeFuture`.
struct StartedHandshakeFutureInner<F, S> {
    // Closure that begins the handshake (in this file, a call to the
    // `native_tls` connector's `connect` or the acceptor's `accept`).
    f: F,
    // The raw transport, not yet wrapped in `AllowStd`.
    stream: S,
}
| 97 | |
| 98 | struct Guard<'a, S>(&'a mut TlsStream<S>) |
| 99 | where |
| 100 | AllowStd<S>: Read + Write; |
| 101 | |
| 102 | impl<S> Drop for Guard<'_, S> |
| 103 | where |
| 104 | AllowStd<S>: Read + Write, |
| 105 | { |
| 106 | fn drop(&mut self) { |
| 107 | (self.0).0.get_mut().context = null_mut(); |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | // *mut () context is neither Send nor Sync |
| 112 | unsafe impl<S: Send> Send for AllowStd<S> {} |
| 113 | unsafe impl<S: Sync> Sync for AllowStd<S> {} |
| 114 | |
| 115 | impl<S> AllowStd<S> |
| 116 | where |
| 117 | S: Unpin, |
| 118 | { |
| 119 | fn with_context<F, R>(&mut self, f: F) -> io::Result<R> |
| 120 | where |
| 121 | F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>, |
| 122 | { |
| 123 | unsafe { |
| 124 | assert!(!self.context.is_null()); |
| 125 | let waker: &mut () = &mut *(self.context as *mut _); |
| 126 | match f(waker, Pin::new(&mut self.inner)) { |
| 127 | Poll::Ready(r: Result) => r, |
| 128 | Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
| 129 | } |
| 130 | } |
| 131 | } |
| 132 | } |
| 133 | |
| 134 | impl<S> Read for AllowStd<S> |
| 135 | where |
| 136 | S: AsyncRead + Unpin, |
| 137 | { |
| 138 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
| 139 | let mut buf: ReadBuf<'_> = ReadBuf::new(buf); |
| 140 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_read(cx:ctx, &mut buf))?; |
| 141 | Ok(buf.filled().len()) |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | impl<S> Write for AllowStd<S> |
| 146 | where |
| 147 | S: AsyncWrite + Unpin, |
| 148 | { |
| 149 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
| 150 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_write(cx:ctx, buf)) |
| 151 | } |
| 152 | |
| 153 | fn flush(&mut self) -> io::Result<()> { |
| 154 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_flush(cx:ctx)) |
| 155 | } |
| 156 | } |
| 157 | |
| 158 | impl<S> TlsStream<S> { |
| 159 | fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>> |
| 160 | where |
| 161 | F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>, |
| 162 | AllowStd<S>: Read + Write, |
| 163 | { |
| 164 | self.0.get_mut().context = ctx as *mut _ as *mut (); |
| 165 | let g = Guard(self); |
| 166 | match f(&mut (g.0).0) { |
| 167 | Ok(v) => Poll::Ready(Ok(v)), |
| 168 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
| 169 | Err(e) => Poll::Ready(Err(e)), |
| 170 | } |
| 171 | } |
| 172 | |
| 173 | /// Returns a shared reference to the inner stream. |
| 174 | pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> { |
| 175 | &self.0 |
| 176 | } |
| 177 | |
| 178 | /// Returns a mutable reference to the inner stream. |
| 179 | pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> { |
| 180 | &mut self.0 |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | impl<S> AsyncRead for TlsStream<S> |
| 185 | where |
| 186 | S: AsyncRead + AsyncWrite + Unpin, |
| 187 | { |
| 188 | fn poll_read( |
| 189 | mut self: Pin<&mut Self>, |
| 190 | ctx: &mut Context<'_>, |
| 191 | buf: &mut ReadBuf<'_>, |
| 192 | ) -> Poll<io::Result<()>> { |
| 193 | self.with_context(ctx, |s: &mut TlsStream>| { |
| 194 | let n: usize = s.read(buf.initialize_unfilled())?; |
| 195 | buf.advance(n); |
| 196 | Ok(()) |
| 197 | }) |
| 198 | } |
| 199 | } |
| 200 | |
| 201 | impl<S> AsyncWrite for TlsStream<S> |
| 202 | where |
| 203 | S: AsyncRead + AsyncWrite + Unpin, |
| 204 | { |
| 205 | fn poll_write( |
| 206 | mut self: Pin<&mut Self>, |
| 207 | ctx: &mut Context<'_>, |
| 208 | buf: &[u8], |
| 209 | ) -> Poll<io::Result<usize>> { |
| 210 | self.with_context(ctx, |s: &mut TlsStream>| s.write(buf)) |
| 211 | } |
| 212 | |
| 213 | fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 214 | self.with_context(ctx, |s: &mut TlsStream>| s.flush()) |
| 215 | } |
| 216 | |
| 217 | fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 218 | self.with_context(ctx, |s: &mut TlsStream>| s.shutdown()) |
| 219 | } |
| 220 | } |
| 221 | |
| 222 | #[cfg (unix)] |
| 223 | impl<S> AsRawFd for TlsStream<S> |
| 224 | where |
| 225 | S: AsRawFd, |
| 226 | { |
| 227 | fn as_raw_fd(&self) -> RawFd { |
| 228 | self.get_ref().get_ref().get_ref().as_raw_fd() |
| 229 | } |
| 230 | } |
| 231 | |
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying transport `S`.
    fn as_raw_socket(&self) -> RawSocket {
        let transport: &S = self.0.get_ref().get_ref();
        transport.as_raw_socket()
    }
}
| 241 | |
| 242 | async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error> |
| 243 | where |
| 244 | F: FnOnce( |
| 245 | AllowStd<S>, |
| 246 | ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
| 247 | + Unpin, |
| 248 | S: AsyncRead + AsyncWrite + Unpin, |
| 249 | { |
| 250 | let start: StartedHandshakeFuture = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); |
| 251 | |
| 252 | match start.await { |
| 253 | Err(e: Error) => Err(e), |
| 254 | Ok(StartedHandshake::Done(s: TlsStream)) => Ok(s), |
| 255 | Ok(StartedHandshake::Mid(s: MidHandshakeTlsStream>)) => MidHandshake(Some(s)).await, |
| 256 | } |
| 257 | } |
| 258 | |
| 259 | impl<F, S> Future for StartedHandshakeFuture<F, S> |
| 260 | where |
| 261 | F: FnOnce( |
| 262 | AllowStd<S>, |
| 263 | ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
| 264 | + Unpin, |
| 265 | S: Unpin, |
| 266 | AllowStd<S>: Read + Write, |
| 267 | { |
| 268 | type Output = Result<StartedHandshake<S>, Error>; |
| 269 | |
| 270 | fn poll( |
| 271 | mut self: Pin<&mut Self>, |
| 272 | ctx: &mut Context<'_>, |
| 273 | ) -> Poll<Result<StartedHandshake<S>, Error>> { |
| 274 | let inner = self.0.take().expect("future polled after completion" ); |
| 275 | let stream = AllowStd { |
| 276 | inner: inner.stream, |
| 277 | context: ctx as *mut _ as *mut (), |
| 278 | }; |
| 279 | |
| 280 | match (inner.f)(stream) { |
| 281 | Ok(mut s) => { |
| 282 | s.get_mut().context = null_mut(); |
| 283 | Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) |
| 284 | } |
| 285 | Err(HandshakeError::WouldBlock(mut s)) => { |
| 286 | s.get_mut().context = null_mut(); |
| 287 | Poll::Ready(Ok(StartedHandshake::Mid(s))) |
| 288 | } |
| 289 | Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), |
| 290 | } |
| 291 | } |
| 292 | } |
| 293 | |
| 294 | impl TlsConnector { |
| 295 | /// Connects the provided stream with this connector, assuming the provided |
| 296 | /// domain. |
| 297 | /// |
| 298 | /// This function will internally call `TlsConnector::connect` to connect |
| 299 | /// the stream and returns a future representing the resolution of the |
| 300 | /// connection operation. The returned future will resolve to either |
| 301 | /// `TlsStream<S>` or `Error` depending if it's successful or not. |
| 302 | /// |
| 303 | /// This is typically used for clients who have already established, for |
| 304 | /// example, a TCP connection to a remote server. That stream is then |
| 305 | /// provided here to perform the client half of a connection to a |
| 306 | /// TLS-powered server. |
| 307 | pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, Error> |
| 308 | where |
| 309 | S: AsyncRead + AsyncWrite + Unpin, |
| 310 | { |
| 311 | handshake(f:move |s: AllowStd| self.0.connect(domain, stream:s), stream).await |
| 312 | } |
| 313 | } |
| 314 | |
| 315 | impl fmt::Debug for TlsConnector { |
| 316 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 317 | f.debug_struct(name:"TlsConnector" ).finish() |
| 318 | } |
| 319 | } |
| 320 | |
| 321 | impl From<native_tls::TlsConnector> for TlsConnector { |
| 322 | fn from(inner: native_tls::TlsConnector) -> TlsConnector { |
| 323 | TlsConnector(inner) |
| 324 | } |
| 325 | } |
| 326 | |
| 327 | impl TlsAcceptor { |
| 328 | /// Accepts a new client connection with the provided stream. |
| 329 | /// |
| 330 | /// This function will internally call `TlsAcceptor::accept` to connect |
| 331 | /// the stream and returns a future representing the resolution of the |
| 332 | /// connection operation. The returned future will resolve to either |
| 333 | /// `TlsStream<S>` or `Error` depending if it's successful or not. |
| 334 | /// |
| 335 | /// This is typically used after a new socket has been accepted from a |
| 336 | /// `TcpListener`. That socket is then passed to this function to perform |
| 337 | /// the server half of accepting a client connection. |
| 338 | pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, Error> |
| 339 | where |
| 340 | S: AsyncRead + AsyncWrite + Unpin, |
| 341 | { |
| 342 | handshake(f:move |s: AllowStd| self.0.accept(stream:s), stream).await |
| 343 | } |
| 344 | } |
| 345 | |
| 346 | impl fmt::Debug for TlsAcceptor { |
| 347 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 348 | f.debug_struct(name:"TlsAcceptor" ).finish() |
| 349 | } |
| 350 | } |
| 351 | |
| 352 | impl From<native_tls::TlsAcceptor> for TlsAcceptor { |
| 353 | fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { |
| 354 | TlsAcceptor(inner) |
| 355 | } |
| 356 | } |
| 357 | |
| 358 | impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> { |
| 359 | type Output = Result<TlsStream<S>, Error>; |
| 360 | |
| 361 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 362 | let mut_self: &mut MidHandshake = self.get_mut(); |
| 363 | let mut s: MidHandshakeTlsStream> = mut_self.0.take().expect(msg:"future polled after completion" ); |
| 364 | |
| 365 | s.get_mut().context = cx as *mut _ as *mut (); |
| 366 | match s.handshake() { |
| 367 | Ok(mut s: TlsStream>) => { |
| 368 | s.get_mut().context = null_mut(); |
| 369 | Poll::Ready(Ok(TlsStream(s))) |
| 370 | } |
| 371 | Err(HandshakeError::WouldBlock(mut s: MidHandshakeTlsStream>)) => { |
| 372 | s.get_mut().context = null_mut(); |
| 373 | mut_self.0 = Some(s); |
| 374 | Poll::Pending |
| 375 | } |
| 376 | Err(HandshakeError::Failure(e: Error)) => Poll::Ready(Err(e)), |
| 377 | } |
| 378 | } |
| 379 | } |
| 380 | |
| 381 | /// re-export native_tls |
| 382 | pub mod native_tls { |
| 383 | pub use native_tls::*; |
| 384 | } |
| 385 | |