1 | use std::{ |
2 | future::Future, |
3 | pin::Pin, |
4 | task::{Context, Poll}, |
5 | }; |
6 | |
7 | use futures_core::ready; |
8 | use http::HeaderMap; |
9 | use http_body::{Body, Frame}; |
10 | use pin_project_lite::pin_project; |
11 | |
12 | pin_project! { |
13 | /// Adds trailers to a body. |
14 | /// |
15 | /// See [`BodyExt::with_trailers`] for more details. |
16 | pub struct WithTrailers<T, F> { |
17 | #[pin] |
18 | state: State<T, F>, |
19 | } |
20 | } |
21 | |
22 | impl<T, F> WithTrailers<T, F> { |
23 | pub(crate) fn new(body: T, trailers: F) -> Self { |
24 | Self { |
25 | state: State::PollBody { |
26 | body, |
27 | trailers: Some(trailers), |
28 | }, |
29 | } |
30 | } |
31 | } |
32 | |
33 | pin_project! { |
34 | #[project = StateProj] |
35 | enum State<T, F> { |
36 | PollBody { |
37 | #[pin] |
38 | body: T, |
39 | trailers: Option<F>, |
40 | }, |
41 | PollTrailers { |
42 | #[pin] |
43 | trailers: F, |
44 | prev_trailers: Option<HeaderMap>, |
45 | }, |
46 | Done, |
47 | } |
48 | } |
49 | |
50 | impl<T, F> Body for WithTrailers<T, F> |
51 | where |
52 | T: Body, |
53 | F: Future<Output = Option<Result<HeaderMap, T::Error>>>, |
54 | { |
55 | type Data = T::Data; |
56 | type Error = T::Error; |
57 | |
58 | fn poll_frame( |
59 | mut self: Pin<&mut Self>, |
60 | cx: &mut Context<'_>, |
61 | ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { |
62 | loop { |
63 | let mut this = self.as_mut().project(); |
64 | |
65 | match this.state.as_mut().project() { |
66 | StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) { |
67 | Some(frame) => match frame.into_trailers() { |
68 | Ok(prev_trailers) => { |
69 | let trailers = trailers.take().unwrap(); |
70 | this.state.set(State::PollTrailers { |
71 | trailers, |
72 | prev_trailers: Some(prev_trailers), |
73 | }); |
74 | } |
75 | Err(frame) => { |
76 | return Poll::Ready(Some(Ok(frame))); |
77 | } |
78 | }, |
79 | None => { |
80 | let trailers = trailers.take().unwrap(); |
81 | this.state.set(State::PollTrailers { |
82 | trailers, |
83 | prev_trailers: None, |
84 | }); |
85 | } |
86 | }, |
87 | StateProj::PollTrailers { |
88 | trailers, |
89 | prev_trailers, |
90 | } => { |
91 | let trailers = ready!(trailers.poll(cx)?); |
92 | match (trailers, prev_trailers.take()) { |
93 | (None, None) => return Poll::Ready(None), |
94 | (None, Some(trailers)) | (Some(trailers), None) => { |
95 | this.state.set(State::Done); |
96 | return Poll::Ready(Some(Ok(Frame::trailers(trailers)))); |
97 | } |
98 | (Some(new_trailers), Some(mut prev_trailers)) => { |
99 | prev_trailers.extend(new_trailers); |
100 | this.state.set(State::Done); |
101 | return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers)))); |
102 | } |
103 | } |
104 | } |
105 | StateProj::Done => { |
106 | return Poll::Ready(None); |
107 | } |
108 | } |
109 | } |
110 | } |
111 | |
112 | #[inline ] |
113 | fn size_hint(&self) -> http_body::SizeHint { |
114 | match &self.state { |
115 | State::PollBody { body, .. } => body.size_hint(), |
116 | State::PollTrailers { .. } | State::Done => Default::default(), |
117 | } |
118 | } |
119 | } |
120 | |
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use bytes::Bytes;
    use http::{HeaderName, HeaderValue};

    use crate::{BodyExt, Empty, Full};

    #[allow(unused_imports)]
    use super::*;

    /// Builds a header map containing only `name: value`.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static(name),
            HeaderValue::from_static(value),
        );
        headers
    }

    /// Returns an already-completed trailers future resolving to a copy of `headers`.
    fn ready_trailers(
        headers: &HeaderMap,
    ) -> std::future::Ready<Option<Result<HeaderMap, Infallible>>> {
        std::future::ready(Some(Ok(headers.clone())))
    }

    /// Extracts the value of a ready poll, panicking if it is still pending.
    fn unwrap_ready<T>(poll: Poll<T>) -> T {
        if let Poll::Ready(value) = poll {
            value
        } else {
            panic!("pending")
        }
    }

    /// Polls `body` for its next frame and requires the result to be ready.
    fn poll_once<B>(
        body: Pin<&mut B>,
        cx: &mut Context<'_>,
    ) -> Option<Result<Frame<B::Data>, B::Error>>
    where
        B: Body,
    {
        unwrap_ready(body.poll_frame(cx))
    }

    #[tokio::test]
    async fn works() {
        let expected = single_header("foo", "bar");

        let body = Full::<Bytes>::from("hello").with_trailers(ready_trailers(&expected));
        futures_util::pin_mut!(body);

        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        // First frame: the data of the wrapped body.
        let data = poll_once(body.as_mut(), &mut cx)
            .unwrap()
            .unwrap()
            .into_data()
            .unwrap();
        assert_eq!(data, "hello");

        // Second frame: the trailers supplied by the future.
        let body_trailers = poll_once(body.as_mut(), &mut cx)
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();
        assert_eq!(body_trailers, expected);

        // Nothing is left afterwards.
        assert!(poll_once(body.as_mut(), &mut cx).is_none());
    }

    #[tokio::test]
    async fn merges_trailers() {
        let first = single_header("foo", "bar");
        let second = single_header("baz", "qux");

        let body = Empty::<Bytes>::new()
            .with_trailers(ready_trailers(&first))
            .with_trailers(ready_trailers(&second));
        futures_util::pin_mut!(body);

        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        // The only frame carries the trailers of both layers.
        let body_trailers = poll_once(body.as_mut(), &mut cx)
            .unwrap()
            .unwrap()
            .into_trailers()
            .unwrap();

        let mut expected = HeaderMap::new();
        expected.extend(first);
        expected.extend(second);
        assert_eq!(body_trailers, expected);

        assert!(poll_once(body.as_mut(), &mut cx).is_none());
    }
}
214 | |