1 | #![doc (html_root_url = "https://docs.rs/tokio-native-tls/0.3.0" )] |
2 | #![warn ( |
3 | missing_debug_implementations, |
4 | missing_docs, |
5 | rust_2018_idioms, |
6 | unreachable_pub |
7 | )] |
8 | #![deny (rustdoc::broken_intra_doc_links)] |
9 | #![doc (test( |
10 | no_crate_inject, |
11 | attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) |
12 | ))] |
13 | |
14 | //! Async TLS streams |
15 | //! |
16 | //! This library is an implementation of TLS streams using the most appropriate |
17 | //! system library by default for negotiating the connection. That is, on |
18 | //! Windows this library uses SChannel, on OSX it uses SecureTransport, and on |
19 | //! other platforms it uses OpenSSL. |
20 | //! |
//! Each TLS stream implements Tokio's `AsyncRead` and `AsyncWrite` traits so
//! that it interoperates with the rest of the Tokio I/O ecosystem. Client
//! connections initiated from this crate verify hostnames automatically and by default.
24 | //! |
25 | //! This crate primarily exports this ability through two newtypes, |
26 | //! `TlsConnector` and `TlsAcceptor`. These newtypes augment the |
27 | //! functionality provided by the `native-tls` crate, on which this crate is |
28 | //! built. Configuration of TLS parameters is still primarily done through the |
29 | //! `native-tls` crate. |
30 | |
31 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
32 | |
33 | use crate::native_tls::{Error, HandshakeError, MidHandshakeTlsStream}; |
34 | use std::fmt; |
35 | use std::future::Future; |
36 | use std::io::{self, Read, Write}; |
37 | use std::marker::Unpin; |
38 | #[cfg (unix)] |
39 | use std::os::unix::io::{AsRawFd, RawFd}; |
40 | #[cfg (windows)] |
41 | use std::os::windows::io::{AsRawSocket, RawSocket}; |
42 | use std::pin::Pin; |
43 | use std::ptr::null_mut; |
44 | use std::task::{Context, Poll}; |
45 | |
/// Adapter that exposes an async stream `S` through the synchronous
/// `Read`/`Write` interface expected by `native_tls`.
///
/// The wrapped stream is reachable through `get_ref` and `get_mut`.
#[derive(Debug)]
pub struct AllowStd<S> {
    // The wrapped async stream.
    inner: S,
    // Type-erased `*mut Context<'_>` installed for the duration of a poll;
    // null when no poll is in progress.
    context: *mut (),
}

impl<S> AllowStd<S> {
    /// Borrows the wrapped stream immutably.
    pub fn get_ref(&self) -> &S {
        let AllowStd { inner, .. } = self;
        inner
    }

    /// Borrows the wrapped stream mutably.
    pub fn get_mut(&mut self) -> &mut S {
        let AllowStd { inner, .. } = self;
        inner
    }
}
64 | |
65 | /// A wrapper around an underlying raw stream which implements the TLS or SSL |
66 | /// protocol. |
67 | /// |
68 | /// A `TlsStream<S>` represents a handshake that has been completed successfully |
69 | /// and both the server and the client are ready for receiving and sending |
70 | /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written |
71 | /// to a `TlsStream` are encrypted when passing through to `S`. |
72 | #[derive (Debug)] |
73 | pub struct TlsStream<S>(native_tls::TlsStream<AllowStd<S>>); |
74 | |
75 | /// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` |
76 | /// method. |
77 | #[derive (Clone)] |
78 | pub struct TlsConnector(native_tls::TlsConnector); |
79 | |
80 | /// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` |
81 | /// method. |
82 | #[derive (Clone)] |
83 | pub struct TlsAcceptor(native_tls::TlsAcceptor); |
84 | |
85 | struct MidHandshake<S>(Option<MidHandshakeTlsStream<AllowStd<S>>>); |
86 | |
87 | enum StartedHandshake<S> { |
88 | Done(TlsStream<S>), |
89 | Mid(MidHandshakeTlsStream<AllowStd<S>>), |
90 | } |
91 | |
92 | struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>); |
93 | struct StartedHandshakeFutureInner<F, S> { |
94 | f: F, |
95 | stream: S, |
96 | } |
97 | |
98 | struct Guard<'a, S>(&'a mut TlsStream<S>) |
99 | where |
100 | AllowStd<S>: Read + Write; |
101 | |
102 | impl<S> Drop for Guard<'_, S> |
103 | where |
104 | AllowStd<S>: Read + Write, |
105 | { |
106 | fn drop(&mut self) { |
107 | (self.0).0.get_mut().context = null_mut(); |
108 | } |
109 | } |
110 | |
111 | // *mut () context is neither Send nor Sync |
112 | unsafe impl<S: Send> Send for AllowStd<S> {} |
113 | unsafe impl<S: Sync> Sync for AllowStd<S> {} |
114 | |
115 | impl<S> AllowStd<S> |
116 | where |
117 | S: Unpin, |
118 | { |
119 | fn with_context<F, R>(&mut self, f: F) -> io::Result<R> |
120 | where |
121 | F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<io::Result<R>>, |
122 | { |
123 | unsafe { |
124 | assert!(!self.context.is_null()); |
125 | let waker: &mut Context<'_> = &mut *(self.context as *mut _); |
126 | match f(waker, Pin::new(&mut self.inner)) { |
127 | Poll::Ready(r: Result) => r, |
128 | Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), |
129 | } |
130 | } |
131 | } |
132 | } |
133 | |
134 | impl<S> Read for AllowStd<S> |
135 | where |
136 | S: AsyncRead + Unpin, |
137 | { |
138 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
139 | let mut buf: ReadBuf<'_> = ReadBuf::new(buf); |
140 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_read(cx:ctx, &mut buf))?; |
141 | Ok(buf.filled().len()) |
142 | } |
143 | } |
144 | |
145 | impl<S> Write for AllowStd<S> |
146 | where |
147 | S: AsyncWrite + Unpin, |
148 | { |
149 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
150 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_write(cx:ctx, buf)) |
151 | } |
152 | |
153 | fn flush(&mut self) -> io::Result<()> { |
154 | self.with_context(|ctx: &mut Context<'_>, stream: Pin<&mut S>| stream.poll_flush(cx:ctx)) |
155 | } |
156 | } |
157 | |
158 | impl<S> TlsStream<S> { |
159 | fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> Poll<io::Result<R>> |
160 | where |
161 | F: FnOnce(&mut native_tls::TlsStream<AllowStd<S>>) -> io::Result<R>, |
162 | AllowStd<S>: Read + Write, |
163 | { |
164 | self.0.get_mut().context = ctx as *mut _ as *mut (); |
165 | let g = Guard(self); |
166 | match f(&mut (g.0).0) { |
167 | Ok(v) => Poll::Ready(Ok(v)), |
168 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
169 | Err(e) => Poll::Ready(Err(e)), |
170 | } |
171 | } |
172 | |
173 | /// Returns a shared reference to the inner stream. |
174 | pub fn get_ref(&self) -> &native_tls::TlsStream<AllowStd<S>> { |
175 | &self.0 |
176 | } |
177 | |
178 | /// Returns a mutable reference to the inner stream. |
179 | pub fn get_mut(&mut self) -> &mut native_tls::TlsStream<AllowStd<S>> { |
180 | &mut self.0 |
181 | } |
182 | } |
183 | |
184 | impl<S> AsyncRead for TlsStream<S> |
185 | where |
186 | S: AsyncRead + AsyncWrite + Unpin, |
187 | { |
188 | fn poll_read( |
189 | mut self: Pin<&mut Self>, |
190 | ctx: &mut Context<'_>, |
191 | buf: &mut ReadBuf<'_>, |
192 | ) -> Poll<io::Result<()>> { |
193 | self.with_context(ctx, |s: &mut TlsStream>| { |
194 | let n: usize = s.read(buf:buf.initialize_unfilled())?; |
195 | buf.advance(n); |
196 | Ok(()) |
197 | }) |
198 | } |
199 | } |
200 | |
201 | impl<S> AsyncWrite for TlsStream<S> |
202 | where |
203 | S: AsyncRead + AsyncWrite + Unpin, |
204 | { |
205 | fn poll_write( |
206 | mut self: Pin<&mut Self>, |
207 | ctx: &mut Context<'_>, |
208 | buf: &[u8], |
209 | ) -> Poll<io::Result<usize>> { |
210 | self.with_context(ctx, |s: &mut TlsStream>| s.write(buf)) |
211 | } |
212 | |
213 | fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
214 | self.with_context(ctx, |s: &mut TlsStream>| s.flush()) |
215 | } |
216 | |
217 | fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> { |
218 | self.with_context(ctx, |s: &mut TlsStream>| s.shutdown()) |
219 | } |
220 | } |
221 | |
222 | #[cfg (unix)] |
223 | impl<S> AsRawFd for TlsStream<S> |
224 | where |
225 | S: AsRawFd, |
226 | { |
227 | fn as_raw_fd(&self) -> RawFd { |
228 | self.get_ref().get_ref().get_ref().as_raw_fd() |
229 | } |
230 | } |
231 | |
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying transport `S`.
    fn as_raw_socket(&self) -> RawSocket {
        let transport: &S = self.0.get_ref().get_ref();
        transport.as_raw_socket()
    }
}
241 | |
242 | async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error> |
243 | where |
244 | F: FnOnce( |
245 | AllowStd<S>, |
246 | ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
247 | + Unpin, |
248 | S: AsyncRead + AsyncWrite + Unpin, |
249 | { |
250 | let start: StartedHandshakeFuture = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })); |
251 | |
252 | match start.await { |
253 | Err(e: Error) => Err(e), |
254 | Ok(StartedHandshake::Done(s: TlsStream)) => Ok(s), |
255 | Ok(StartedHandshake::Mid(s: MidHandshakeTlsStream>)) => MidHandshake(Some(s)).await, |
256 | } |
257 | } |
258 | |
259 | impl<F, S> Future for StartedHandshakeFuture<F, S> |
260 | where |
261 | F: FnOnce( |
262 | AllowStd<S>, |
263 | ) -> Result<native_tls::TlsStream<AllowStd<S>>, HandshakeError<AllowStd<S>>> |
264 | + Unpin, |
265 | S: Unpin, |
266 | AllowStd<S>: Read + Write, |
267 | { |
268 | type Output = Result<StartedHandshake<S>, Error>; |
269 | |
270 | fn poll( |
271 | mut self: Pin<&mut Self>, |
272 | ctx: &mut Context<'_>, |
273 | ) -> Poll<Result<StartedHandshake<S>, Error>> { |
274 | let inner = self.0.take().expect("future polled after completion" ); |
275 | let stream = AllowStd { |
276 | inner: inner.stream, |
277 | context: ctx as *mut _ as *mut (), |
278 | }; |
279 | |
280 | match (inner.f)(stream) { |
281 | Ok(mut s) => { |
282 | s.get_mut().context = null_mut(); |
283 | Poll::Ready(Ok(StartedHandshake::Done(TlsStream(s)))) |
284 | } |
285 | Err(HandshakeError::WouldBlock(mut s)) => { |
286 | s.get_mut().context = null_mut(); |
287 | Poll::Ready(Ok(StartedHandshake::Mid(s))) |
288 | } |
289 | Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), |
290 | } |
291 | } |
292 | } |
293 | |
294 | impl TlsConnector { |
295 | /// Connects the provided stream with this connector, assuming the provided |
296 | /// domain. |
297 | /// |
298 | /// This function will internally call `TlsConnector::connect` to connect |
299 | /// the stream and returns a future representing the resolution of the |
300 | /// connection operation. The returned future will resolve to either |
301 | /// `TlsStream<S>` or `Error` depending if it's successful or not. |
302 | /// |
303 | /// This is typically used for clients who have already established, for |
304 | /// example, a TCP connection to a remote server. That stream is then |
305 | /// provided here to perform the client half of a connection to a |
306 | /// TLS-powered server. |
307 | pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, Error> |
308 | where |
309 | S: AsyncRead + AsyncWrite + Unpin, |
310 | { |
311 | handshake(f:move |s: AllowStd| self.0.connect(domain, stream:s), stream).await |
312 | } |
313 | } |
314 | |
315 | impl fmt::Debug for TlsConnector { |
316 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
317 | f.debug_struct(name:"TlsConnector" ).finish() |
318 | } |
319 | } |
320 | |
321 | impl From<native_tls::TlsConnector> for TlsConnector { |
322 | fn from(inner: native_tls::TlsConnector) -> TlsConnector { |
323 | TlsConnector(inner) |
324 | } |
325 | } |
326 | |
327 | impl TlsAcceptor { |
328 | /// Accepts a new client connection with the provided stream. |
329 | /// |
330 | /// This function will internally call `TlsAcceptor::accept` to connect |
331 | /// the stream and returns a future representing the resolution of the |
332 | /// connection operation. The returned future will resolve to either |
333 | /// `TlsStream<S>` or `Error` depending if it's successful or not. |
334 | /// |
335 | /// This is typically used after a new socket has been accepted from a |
336 | /// `TcpListener`. That socket is then passed to this function to perform |
337 | /// the server half of accepting a client connection. |
338 | pub async fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, Error> |
339 | where |
340 | S: AsyncRead + AsyncWrite + Unpin, |
341 | { |
342 | handshake(f:move |s: AllowStd| self.0.accept(stream:s), stream).await |
343 | } |
344 | } |
345 | |
346 | impl fmt::Debug for TlsAcceptor { |
347 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
348 | f.debug_struct(name:"TlsAcceptor" ).finish() |
349 | } |
350 | } |
351 | |
352 | impl From<native_tls::TlsAcceptor> for TlsAcceptor { |
353 | fn from(inner: native_tls::TlsAcceptor) -> TlsAcceptor { |
354 | TlsAcceptor(inner) |
355 | } |
356 | } |
357 | |
358 | impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> { |
359 | type Output = Result<TlsStream<S>, Error>; |
360 | |
361 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
362 | let mut_self: &mut MidHandshake = self.get_mut(); |
363 | let mut s: MidHandshakeTlsStream> = mut_self.0.take().expect(msg:"future polled after completion" ); |
364 | |
365 | s.get_mut().context = cx as *mut _ as *mut (); |
366 | match s.handshake() { |
367 | Ok(mut s: TlsStream>) => { |
368 | s.get_mut().context = null_mut(); |
369 | Poll::Ready(Ok(TlsStream(s))) |
370 | } |
371 | Err(HandshakeError::WouldBlock(mut s: MidHandshakeTlsStream>)) => { |
372 | s.get_mut().context = null_mut(); |
373 | mut_self.0 = Some(s); |
374 | Poll::Pending |
375 | } |
376 | Err(HandshakeError::Failure(e: Error)) => Poll::Ready(Err(e)), |
377 | } |
378 | } |
379 | } |
380 | |
381 | /// re-export native_tls |
382 | pub mod native_tls { |
383 | pub use native_tls::*; |
384 | } |
385 | |