1#![doc(html_root_url = "https://docs.rs/tokio-native-tls/0.3.0")]
2#![warn(
3 missing_debug_implementations,
4 missing_docs,
5 rust_2018_idioms,
6 unreachable_pub
7)]
8#![deny(rustdoc::broken_intra_doc_links)]
9#![doc(test(
10 no_crate_inject,
11 attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
12))]
13
//! Async TLS streams
//!
//! This library is an implementation of TLS streams using the most appropriate
//! system library by default for negotiating the connection. That is, on
//! Windows this library uses SChannel, on macOS it uses Secure Transport, and on
//! other platforms it uses OpenSSL.
//!
//! Each TLS stream implements Tokio's `AsyncRead` and `AsyncWrite` traits to
//! interact and interoperate with the rest of the async I/O ecosystem. Client
//! connections initiated from this crate verify hostnames automatically and by
//! default.
//!
//! This crate primarily exports this ability through two newtypes,
//! `TlsConnector` and `TlsAcceptor`. These newtypes augment the
//! functionality provided by the `native-tls` crate, on which this crate is
//! built. Configuration of TLS parameters is still primarily done through the
//! `native-tls` crate.
30
31use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
32
33use crate::native_tls::{Error, HandshakeError, MidHandshakeTlsStream};
34use std::fmt;
35use std::future::Future;
36use std::io::{self, Read, Write};
37use std::marker::Unpin;
38#[cfg(unix)]
39use std::os::unix::io::{AsRawFd, RawFd};
40#[cfg(windows)]
41use std::os::windows::io::{AsRawSocket, RawSocket};
42use std::pin::Pin;
43use std::ptr::null_mut;
44use std::task::{Context, Poll};
45
/// An intermediate wrapper for the inner stream `S`.
///
/// `native-tls` drives its stream through the blocking-style `std::io::Read`
/// and `std::io::Write` traits. `AllowStd` implements those traits on top of
/// Tokio's `AsyncRead`/`AsyncWrite` by temporarily holding the task `Context`
/// of the poll that is currently in progress.
#[derive(Debug)]
pub struct AllowStd<S> {
    inner: S,
    // Type-erased pointer to the `Context<'_>` of the poll currently running.
    // It is set right before `native-tls` is invoked and reset to null after.
    context: *mut (),
}

impl<S> AllowStd<S> {
    /// Returns a shared reference to the inner stream.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the inner stream.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}
64
65/// A wrapper around an underlying raw stream which implements the TLS or SSL
66/// protocol.
67///
68/// A `TlsStream<S>` represents a handshake that has been completed successfully
69/// and both the server and the client are ready for receiving and sending
70/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
71/// to a `TlsStream` are encrypted when passing through to `S`.
72#[derive(Debug)]
73pub struct TlsStream<S>(native_tls::TlsStream<AllowStd<S>>);
74
75/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect`
76/// method.
77#[derive(Clone)]
78pub struct TlsConnector(native_tls::TlsConnector);
79
80/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept`
81/// method.
82#[derive(Clone)]
83pub struct TlsAcceptor(native_tls::TlsAcceptor);
84
85struct MidHandshake<S>(Option<MidHandshakeTlsStream<AllowStd<S>>>);
86
87enum StartedHandshake<S> {
88 Done(TlsStream<S>),
89 Mid(MidHandshakeTlsStream<AllowStd<S>>),
90}
91
92struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
93struct StartedHandshakeFutureInner<F, S> {
94 f: F,
95 stream: S,
96}
97
98struct Guard<'a, S>(&'a mut TlsStream<S>)
99where
100 AllowStd<S>: Read + Write;
101
102impl<S> Drop for Guard<'_, S>
103where
104 AllowStd<S>: Read + Write,
105{
106 fn drop(&mut self) {
107 (self.0).0.get_mut().context = null_mut();
108 }
109}
110
111// *mut () context is neither Send nor Sync
112unsafe impl<S: Send> Send for AllowStd<S> {}
113unsafe impl<S: Sync> Sync for AllowStd<S> {}
114
115impl<S> AllowStd<S>
116where
117 S: Unpin,
118{
119 fn with_context<F, R>(&mut self, f: F) -> io::Result<R>
120 where
121 F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>,
122 {
123 unsafe {
124 assert!(!self.context.is_null());
125 let waker: &mut Context<'_> = &mut *(self.context as *mut _);
126 match f(waker, Pin::new(&mut self.inner)) {
127 Poll::Ready(r: Result) => r,
128 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
129 }
130 }
131 }
132}
133
134impl<S> Read for AllowStd<S>
135where
136 S: AsyncRead + Unpin,
137{
138 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
139 let mut buf: ReadBuf<'_> = ReadBuf::new(buf);
140 self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_read(cx:ctx, &mut buf))?;
141 Ok(buf.filled().len())
142 }
143}
144
145impl<S> Write for AllowStd<S>
146where
147 S: AsyncWrite + Unpin,
148{
149 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
150 self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_write(cx:ctx, buf))
151 }
152
153 fn flush(&mut self) -> io::Result<()> {
154 self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_flush(cx:ctx))
155 }
156}
157
158impl<S> TlsStream<S> {
159 fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>>
160 where
161 F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>,
162 AllowStd<S>: Read + Write,
163 {
164 self.0.get_mut().context = ctx as *mut _ as *mut ();
165 let g = Guard(self);
166 match f(&mut (g.0).0) {
167 Ok(v) => Poll::Ready(Ok(v)),
168 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
169 Err(e) => Poll::Ready(Err(e)),
170 }
171 }
172
173 /// Returns a shared reference to the inner stream.
174 pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> {
175 &self.0
176 }
177
178 /// Returns a mutable reference to the inner stream.
179 pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> {
180 &mut self.0
181 }
182}
183
184impl<S> AsyncRead for TlsStream<S>
185where
186 S: AsyncRead + AsyncWrite + Unpin,
187{
188 fn poll_read(
189 mut self: Pin<&mut Self>,
190 ctx: &mut Context<'_>,
191 buf: &mut ReadBuf<'_>,
192 ) -> Poll<io::Result<()>> {
193 self.with_context(ctx, |s: &mut TlsStream>| {
194 let n: usize = s.read(buf:buf.initialize_unfilled())?;
195 buf.advance(n);
196 Ok(())
197 })
198 }
199}
200
201impl<S> AsyncWrite for TlsStream<S>
202where
203 S: AsyncRead + AsyncWrite + Unpin,
204{
205 fn poll_write(
206 mut self: Pin<&mut Self>,
207 ctx: &mut Context<'_>,
208 buf: &[u8],
209 ) -> Poll<io::Result<usize>> {
210 self.with_context(ctx, |s: &mut TlsStream>| s.write(buf))
211 }
212
213 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
214 self.with_context(ctx, |s: &mut TlsStream>| s.flush())
215 }
216
217 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
218 self.with_context(ctx, |s: &mut TlsStream>| s.shutdown())
219 }
220}
221
222#[cfg(unix)]
223impl<S> AsRawFd for TlsStream<S>
224where
225 S: AsRawFd,
226{
227 fn as_raw_fd(&self) -> RawFd {
228 self.get_ref().get_ref().get_ref().as_raw_fd()
229 }
230}
231
232#[cfg(windows)]
233impl<S> AsRawSocket for TlsStream<S>
234where
235 S: AsRawSocket,
236{
237 fn as_raw_socket(&self) -> RawSocket {
238 self.get_ref().get_ref().get_ref().as_raw_socket()
239 }
240}
241
242async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
243where
244 F: FnOnce(
245 AllowStd<S>,
246 ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>>
247 + Unpin,
248 S: AsyncRead + AsyncWrite + Unpin,
249{
250 let start: StartedHandshakeFuture = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
251
252 match start.await {
253 Err(e: Error) => Err(e),
254 Ok(StartedHandshake::Done(s: TlsStream)) => Ok(s),
255 Ok(StartedHandshake::Mid(s: MidHandshakeTlsStream>)) => MidHandshake(Some(s)).await,
256 }
257}
258
259impl<F, S> Future for StartedHandshakeFuture<F, S>
260where
261 F: FnOnce(
262 AllowStd<S>,
263 ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>>
264 + Unpin,
265 S: Unpin,
266 AllowStd<S>: Read + Write,
267{
268 type Output = Result<StartedHandshake<S>, Error>;
269
270 fn poll(
271 mut self: Pin<&mut Self>,
272 ctx: &mut Context<'_>,
273 ) -> Poll<Result<StartedHandshake<S>, Error>> {
274 let inner = self.0.take().expect("future polled after completion");
275 let stream = AllowStd {
276 inner: inner.stream,
277 context: ctx as *mut _ as *mut (),
278 };
279
280 match (inner.f)(stream) {
281 Ok(mut s) => {
282 s.get_mut().context = null_mut();
283 Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s))))
284 }
285 Err(HandshakeError::WouldBlock(mut s)) => {
286 s.get_mut().context = null_mut();
287 Poll::Ready(Ok(StartedHandshake::Mid(s)))
288 }
289 Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
290 }
291 }
292}
293
294impl TlsConnector {
295 /// Connects the provided stream with this connector, assuming the provided
296 /// domain.
297 ///
298 /// This function will internally call `TlsConnector::connect` to connect
299 /// the stream and returns a future representing the resolution of the
300 /// connection operation. The returned future will resolve to either
301 /// `TlsStream<S>` or `Error` depending if it's successful or not.
302 ///
303 /// This is typically used for clients who have already established, for
304 /// example, a TCP connection to a remote server. That stream is then
305 /// provided here to perform the client half of a connection to a
306 /// TLS-powered server.
307 pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, Error>
308 where
309 S: AsyncRead + AsyncWrite + Unpin,
310 {
311 handshake(f:move |s: AllowStd| self.0.connect(domain, stream:s), stream).await
312 }
313}
314
315impl fmt::Debug for TlsConnector {
316 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 f.debug_struct(name:"TlsConnector").finish()
318 }
319}
320
321impl From<native_tls::TlsConnector> for TlsConnector {
322 fn from(inner: native_tls::TlsConnector) -> TlsConnector {
323 TlsConnector(inner)
324 }
325}
326
327impl TlsAcceptor {
328 /// Accepts a new client connection with the provided stream.
329 ///
330 /// This function will internally call `TlsAcceptor::accept` to connect
331 /// the stream and returns a future representing the resolution of the
332 /// connection operation. The returned future will resolve to either
333 /// `TlsStream<S>` or `Error` depending if it's successful or not.
334 ///
335 /// This is typically used after a new socket has been accepted from a
336 /// `TcpListener`. That socket is then passed to this function to perform
337 /// the server half of accepting a client connection.
338 pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, Error>
339 where
340 S: AsyncRead + AsyncWrite + Unpin,
341 {
342 handshake(f:move |s: AllowStd| self.0.accept(stream:s), stream).await
343 }
344}
345
346impl fmt::Debug for TlsAcceptor {
347 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348 f.debug_struct(name:"TlsAcceptor").finish()
349 }
350}
351
352impl From<native_tls::TlsAcceptor> for TlsAcceptor {
353 fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor {
354 TlsAcceptor(inner)
355 }
356}
357
358impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
359 type Output = Result<TlsStream<S>, Error>;
360
361 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
362 let mut_self: &mut MidHandshake = self.get_mut();
363 let mut s: MidHandshakeTlsStream> = mut_self.0.take().expect(msg:"future polled after completion");
364
365 s.get_mut().context = cx as *mut _ as *mut ();
366 match s.handshake() {
367 Ok(mut s: TlsStream>) => {
368 s.get_mut().context = null_mut();
369 Poll::Ready(Ok(TlsStream(s)))
370 }
371 Err(HandshakeError::WouldBlock(mut s: MidHandshakeTlsStream>)) => {
372 s.get_mut().context = null_mut();
373 mut_self.0 = Some(s);
374 Poll::Pending
375 }
376 Err(HandshakeError::Failure(e: Error)) => Poll::Ready(Err(e)),
377 }
378 }
379}
380
381/// re-export native_tls
382pub mod native_tls {
383 pub use native_tls::*;
384}
385