1 | //! HTTP Upgrades |
2 | //! |
3 | //! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since |
4 | //! several concepts in HTTP allow for first talking HTTP, and then converting |
5 | //! to a different protocol, this module conflates them into a single API. |
6 | //! Those include: |
7 | //! |
8 | //! - HTTP/1.1 Upgrades |
9 | //! - HTTP `CONNECT` |
10 | //! |
//! You are responsible for any other prerequisites of an upgrade, such as
//! sending the appropriate headers, methods, and status codes. You can then
//! use [`on`][] to get a `Future` that resolves to the upgraded connection
//! object, or to an error if the upgrade fails.
15 | //! |
16 | //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism |
17 | //! |
18 | //! # Client |
19 | //! |
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
//! either the appropriate method (for `CONNECT`) or headers such as `Upgrade`
//! and `Connection` on the `http::Request`. After receiving the
//! `http::Response`, you must check that the server agreed to the upgrade
//! (for example, with a `101` status code), and then get the `Future` from the
//! `Response`.
26 | //! |
27 | //! # Server |
28 | //! |
//! Receiving upgrade requests in a server requires you to check the relevant
//! headers in a `Request` and, if an upgrade should be done, send the
//! corresponding headers in a response. To wait for hyper to finish the
//! upgrade, call `on()` with the `Request`, and then spawn a task that awaits
//! the returned future.
34 | //! |
35 | //! # Example |
36 | //! |
37 | //! See [this example][example] showing how upgrades work with both |
38 | //! Clients and Servers. |
39 | //! |
40 | //! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs |
41 | |
42 | use std::any::TypeId; |
43 | use std::error::Error as StdError; |
44 | use std::fmt; |
45 | use std::io; |
46 | use std::marker::Unpin; |
47 | |
48 | use bytes::Bytes; |
49 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
50 | use tokio::sync::oneshot; |
51 | #[cfg (any(feature = "http1" , feature = "http2" ))] |
52 | use tracing::trace; |
53 | |
54 | use crate::common::io::Rewind; |
55 | use crate::common::{task, Future, Pin, Poll}; |
56 | |
/// An upgraded HTTP connection.
///
/// This type holds a trait object internally of the original IO that
/// was used to speak HTTP before the upgrade. It implements `AsyncRead`
/// and `AsyncWrite` (forwarding to that IO), so it can be used directly
/// as a transport for the new protocol.
///
/// Alternatively, if the exact type is known, this can be deconstructed
/// into its [`Parts`] with [`Upgraded::downcast`].
pub struct Upgraded {
    // The original IO, type-erased, wrapped in a `Rewind` together with any
    // bytes that were read before the upgrade completed (surfaced again as
    // `Parts::read_buf` by `downcast`).
    io: Rewind<Box<dyn Io + Send>>,
}
68 | |
/// A future for a possible HTTP upgrade.
///
/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
pub struct OnUpgrade {
    // `None` when the message carried no upgrade (see `OnUpgrade::none`);
    // otherwise the receiving half of the channel whose sender is held by a
    // `Pending` created in `pending()`.
    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
}
75 | |
/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
///
/// Includes the original IO type, and a read buffer of bytes that the
/// HTTP state machine may have already read before completing an upgrade.
#[derive(Debug)]
pub struct Parts<T> {
    /// The original IO object used before the upgrade.
    pub io: T,
    /// A buffer of bytes that have been read but not processed as HTTP.
    ///
    /// For instance, if the `Connection` is used for an HTTP upgrade request,
    /// it is possible the server sent back the first bytes of the new protocol
    /// along with the response upgrade.
    ///
    /// You will want to check for any existing bytes if you plan to continue
    /// communicating on the IO object.
    pub read_buf: Bytes,
    // Private field: prevents building `Parts` with a struct literal (or
    // destructuring it exhaustively) outside this module — presumably so more
    // fields can be added later without a breaking change.
    _inner: (),
}
95 | |
96 | /// Gets a pending HTTP upgrade from this message. |
97 | /// |
98 | /// This can be called on the following types: |
99 | /// |
100 | /// - `http::Request<B>` |
101 | /// - `http::Response<B>` |
102 | /// - `&mut http::Request<B>` |
103 | /// - `&mut http::Response<B>` |
104 | pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade { |
105 | msg.on_upgrade() |
106 | } |
107 | |
/// Sending half of a pending upgrade, paired with an `OnUpgrade` by `pending()`.
///
/// It is consumed by `fulfill`, which delivers the upgraded IO, or by `manual`,
/// which delivers a manual-upgrade error. If it is dropped without either, the
/// paired `OnUpgrade` resolves to a canceled error.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
112 | |
/// Creates a connected `Pending` / `OnUpgrade` pair backed by a oneshot channel.
///
/// Whatever the `Pending` half sends is what the `OnUpgrade` future resolves
/// to. If the `Pending` half is dropped without sending, the future resolves
/// to a canceled error.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (tx, rx) = oneshot::channel();
    (Pending { tx }, OnUpgrade { rx: Some(rx) })
}
118 | |
119 | // ===== impl Upgraded ===== |
120 | |
121 | impl Upgraded { |
122 | #[cfg (any(feature = "http1" , feature = "http2" , test))] |
123 | pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self |
124 | where |
125 | T: AsyncRead + AsyncWrite + Unpin + Send + 'static, |
126 | { |
127 | Upgraded { |
128 | io: Rewind::new_buffered(Box::new(io), read_buf), |
129 | } |
130 | } |
131 | |
132 | /// Tries to downcast the internal trait object to the type passed. |
133 | /// |
134 | /// On success, returns the downcasted parts. On error, returns the |
135 | /// `Upgraded` back. |
136 | pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> { |
137 | let (io, buf) = self.io.into_inner(); |
138 | match io.__hyper_downcast() { |
139 | Ok(t) => Ok(Parts { |
140 | io: *t, |
141 | read_buf: buf, |
142 | _inner: (), |
143 | }), |
144 | Err(io) => Err(Upgraded { |
145 | io: Rewind::new_buffered(io, buf), |
146 | }), |
147 | } |
148 | } |
149 | } |
150 | |
151 | impl AsyncRead for Upgraded { |
152 | fn poll_read( |
153 | mut self: Pin<&mut Self>, |
154 | cx: &mut task::Context<'_>, |
155 | buf: &mut ReadBuf<'_>, |
156 | ) -> Poll<io::Result<()>> { |
157 | Pin::new(&mut self.io).poll_read(cx, buf) |
158 | } |
159 | } |
160 | |
161 | impl AsyncWrite for Upgraded { |
162 | fn poll_write( |
163 | mut self: Pin<&mut Self>, |
164 | cx: &mut task::Context<'_>, |
165 | buf: &[u8], |
166 | ) -> Poll<io::Result<usize>> { |
167 | Pin::new(&mut self.io).poll_write(cx, buf) |
168 | } |
169 | |
170 | fn poll_write_vectored( |
171 | mut self: Pin<&mut Self>, |
172 | cx: &mut task::Context<'_>, |
173 | bufs: &[io::IoSlice<'_>], |
174 | ) -> Poll<io::Result<usize>> { |
175 | Pin::new(&mut self.io).poll_write_vectored(cx, bufs) |
176 | } |
177 | |
178 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
179 | Pin::new(&mut self.io).poll_flush(cx) |
180 | } |
181 | |
182 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
183 | Pin::new(&mut self.io).poll_shutdown(cx) |
184 | } |
185 | |
186 | fn is_write_vectored(&self) -> bool { |
187 | self.io.is_write_vectored() |
188 | } |
189 | } |
190 | |
191 | impl fmt::Debug for Upgraded { |
192 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
193 | f.debug_struct(name:"Upgraded" ).finish() |
194 | } |
195 | } |
196 | |
197 | // ===== impl OnUpgrade ===== |
198 | |
199 | impl OnUpgrade { |
200 | pub(super) fn none() -> Self { |
201 | OnUpgrade { rx: None } |
202 | } |
203 | |
204 | #[cfg (feature = "http1" )] |
205 | pub(super) fn is_none(&self) -> bool { |
206 | self.rx.is_none() |
207 | } |
208 | } |
209 | |
210 | impl Future for OnUpgrade { |
211 | type Output = Result<Upgraded, crate::Error>; |
212 | |
213 | fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { |
214 | match self.rx { |
215 | Some(ref mut rx: &mut Receiver>) => Pin::new(pointer:rx).poll(cx).map(|res: Result, …>| match res { |
216 | Ok(Ok(upgraded: Upgraded)) => Ok(upgraded), |
217 | Ok(Err(err: Error)) => Err(err), |
218 | Err(_oneshot_canceled: RecvError) => Err(crate::Error::new_canceled().with(cause:UpgradeExpected)), |
219 | }), |
220 | None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())), |
221 | } |
222 | } |
223 | } |
224 | |
225 | impl fmt::Debug for OnUpgrade { |
226 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
227 | f.debug_struct(name:"OnUpgrade" ).finish() |
228 | } |
229 | } |
230 | |
231 | // ===== impl Pending ===== |
232 | |
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by sending `upgraded` to the paired `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let outcome = Ok(upgraded);
        // `send` only fails when the receiving `OnUpgrade` is already gone;
        // nobody is waiting then, so the failure is deliberately ignored.
        let _ = self.tx.send(outcome);
    }

    /// Don't fulfill the pending Upgrade, but instead signal that
    /// upgrades are handled manually.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let manual_err = crate::Error::new_user_manual_upgrade();
        // Same as `fulfill`: a closed receiver is not an error here.
        let _ = self.tx.send(Err(manual_err));
    }
}
248 | |
249 | // ===== impl UpgradeExpected ===== |
250 | |
/// Error cause returned when an upgrade was expected but canceled
/// for whatever reason.
///
/// `OnUpgrade` attaches it to the canceled error it yields when its `Pending`
/// sender is dropped without sending. This likely means the actual `Conn`
/// future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("upgrade expected but not completed")
    }
}

impl StdError for UpgradeExpected {}
265 | |
266 | // ===== impl Io ===== |
267 | |
/// Object-safe IO trait used as the type-erased transport inside [`Upgraded`].
///
/// The blanket impl below implements it for every
/// `AsyncRead + AsyncWrite + Unpin + 'static` type. `__hyper_type_id` lets a
/// boxed trait object report its concrete type, so [`Upgraded::downcast`] can
/// recover that type.
pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
    /// Returns the `TypeId` of the concrete type behind this object.
    ///
    /// Because of the blanket impl, no type can provide its own `Io` impl (it
    /// would overlap), so this default body is always the one used. The unsafe
    /// downcast below relies on that.
    fn __hyper_type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}

impl dyn Io + Send {
    /// Returns `true` if the concrete type behind this trait object is `T`.
    fn __hyper_is<T: Io>(&self) -> bool {
        let t: TypeId = TypeId::of::<T>();
        self.__hyper_type_id() == t
    }

    /// Converts the boxed trait object back into `Box<T>` if its concrete type
    /// is `T`. Otherwise, returns the original box unchanged in `Err`.
    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if self.__hyper_is::<T>() {
            // Taken from `std::error::Error::downcast()`.
            // SAFETY: `__hyper_is` just confirmed that the concrete type is `T`,
            // so the data half of this fat pointer points to a valid `T` that was
            // allocated by `Box` and is uniquely owned by `self`. The cast drops
            // the vtable, and `Box::from_raw` re-takes ownership exactly once
            // (the box was consumed by `into_raw`).
            unsafe {
                let raw: *mut dyn Io = Box::into_raw(self);
                Ok(Box::from_raw(raw as *mut T))
            }
        } else {
            Err(self)
        }
    }
}
294 | |
295 | mod sealed { |
296 | use super::OnUpgrade; |
297 | |
298 | pub trait CanUpgrade { |
299 | fn on_upgrade(self) -> OnUpgrade; |
300 | } |
301 | |
302 | impl<B> CanUpgrade for http::Request<B> { |
303 | fn on_upgrade(mut self) -> OnUpgrade { |
304 | self.extensions_mut() |
305 | .remove::<OnUpgrade>() |
306 | .unwrap_or_else(OnUpgrade::none) |
307 | } |
308 | } |
309 | |
310 | impl<B> CanUpgrade for &'_ mut http::Request<B> { |
311 | fn on_upgrade(self) -> OnUpgrade { |
312 | self.extensions_mut() |
313 | .remove::<OnUpgrade>() |
314 | .unwrap_or_else(OnUpgrade::none) |
315 | } |
316 | } |
317 | |
318 | impl<B> CanUpgrade for http::Response<B> { |
319 | fn on_upgrade(mut self) -> OnUpgrade { |
320 | self.extensions_mut() |
321 | .remove::<OnUpgrade>() |
322 | .unwrap_or_else(OnUpgrade::none) |
323 | } |
324 | } |
325 | |
326 | impl<B> CanUpgrade for &'_ mut http::Response<B> { |
327 | fn on_upgrade(self) -> OnUpgrade { |
328 | self.extensions_mut() |
329 | .remove::<OnUpgrade>() |
330 | .unwrap_or_else(OnUpgrade::none) |
331 | } |
332 | } |
333 | } |
334 | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Downcasting to the wrong type must hand the `Upgraded` back intact, so a
    /// follow-up downcast to the real IO type still succeeds.
    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        let not_a_cursor = upgraded.downcast::<std::io::Cursor<Vec<u8>>>();
        let upgraded = not_a_cursor.unwrap_err();

        let as_mock = upgraded.downcast::<Mock>();
        as_mock.unwrap();
    }

    // TODO: replace with tokio_test::io when it can test write_buf
    /// IO stand-in for these tests. Writes report success and discard the
    /// data; every other operation panics, because the tests never expect to
    /// reach it.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let written = buf.len();
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}
383 | |