1 | use crate::{Body, SizeHint}; |
2 | use bytes::Buf; |
3 | use http::HeaderMap; |
4 | use pin_project_lite::pin_project; |
5 | use std::error::Error; |
6 | use std::fmt; |
7 | use std::pin::Pin; |
8 | use std::task::{Context, Poll}; |
9 | |
10 | pin_project! { |
11 | /// A length limited body. |
12 | /// |
13 | /// This body will return an error if more than the configured number |
14 | /// of bytes are returned on polling the wrapped body. |
15 | #[derive(Clone, Copy, Debug)] |
16 | pub struct Limited<B> { |
17 | remaining: usize, |
18 | #[pin] |
19 | inner: B, |
20 | } |
21 | } |
22 | |
23 | impl<B> Limited<B> { |
24 | /// Create a new `Limited`. |
25 | pub fn new(inner: B, limit: usize) -> Self { |
26 | Self { |
27 | remaining: limit, |
28 | inner, |
29 | } |
30 | } |
31 | } |
32 | |
33 | impl<B> Body for Limited<B> |
34 | where |
35 | B: Body, |
36 | B::Error: Into<Box<dyn Error + Send + Sync>>, |
37 | { |
38 | type Data = B::Data; |
39 | type Error = Box<dyn Error + Send + Sync>; |
40 | |
41 | fn poll_data( |
42 | self: Pin<&mut Self>, |
43 | cx: &mut Context<'_>, |
44 | ) -> Poll<Option<Result<Self::Data, Self::Error>>> { |
45 | let this = self.project(); |
46 | let res = match this.inner.poll_data(cx) { |
47 | Poll::Pending => return Poll::Pending, |
48 | Poll::Ready(None) => None, |
49 | Poll::Ready(Some(Ok(data))) => { |
50 | if data.remaining() > *this.remaining { |
51 | *this.remaining = 0; |
52 | Some(Err(LengthLimitError.into())) |
53 | } else { |
54 | *this.remaining -= data.remaining(); |
55 | Some(Ok(data)) |
56 | } |
57 | } |
58 | Poll::Ready(Some(Err(err))) => Some(Err(err.into())), |
59 | }; |
60 | |
61 | Poll::Ready(res) |
62 | } |
63 | |
64 | fn poll_trailers( |
65 | self: Pin<&mut Self>, |
66 | cx: &mut Context<'_>, |
67 | ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { |
68 | let this = self.project(); |
69 | let res = match this.inner.poll_trailers(cx) { |
70 | Poll::Pending => return Poll::Pending, |
71 | Poll::Ready(Ok(data)) => Ok(data), |
72 | Poll::Ready(Err(err)) => Err(err.into()), |
73 | }; |
74 | |
75 | Poll::Ready(res) |
76 | } |
77 | |
78 | fn is_end_stream(&self) -> bool { |
79 | self.inner.is_end_stream() |
80 | } |
81 | |
82 | fn size_hint(&self) -> SizeHint { |
83 | use std::convert::TryFrom; |
84 | match u64::try_from(self.remaining) { |
85 | Ok(n) => { |
86 | let mut hint = self.inner.size_hint(); |
87 | if hint.lower() >= n { |
88 | hint.set_exact(n) |
89 | } else if let Some(max) = hint.upper() { |
90 | hint.set_upper(n.min(max)) |
91 | } else { |
92 | hint.set_upper(n) |
93 | } |
94 | hint |
95 | } |
96 | Err(_) => self.inner.size_hint(), |
97 | } |
98 | } |
99 | } |
100 | |
/// An error returned when body length exceeds the configured limit.
///
/// `Limited::poll_data` yields this error, boxed, when a chunk from the
/// wrapped body is larger than the number of bytes still allowed.
#[derive(Debug)]
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    /// Writes the fixed, human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("length limit exceeded")
    }
}

// No source or extra context to expose, so the default `Error` methods suffice.
impl Error for LengthLimitError {}
113 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Full;
    use bytes::Bytes;
    use std::convert::Infallible;

    // Byte limit applied to the body under test in every case below.
    const LIMIT: usize = 8;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), LIMIT);

        // The 7-byte payload is below the limit, so 7 is the upper bound.
        assert_eq!(body.size_hint().upper(), Some(7));

        let chunk = body.data().await.unwrap().unwrap();
        assert_eq!(chunk, DATA);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), LIMIT);

        assert_eq!(body.size_hint().upper(), Some(8));

        let error = body.data().await.unwrap().unwrap_err();
        assert!(error.is::<LengthLimitError>());
    }

    // Test body that yields a fixed list of static chunks, one per poll.
    struct Chunky(&'static [&'static [u8]]);

    impl Body for Chunky {
        type Data = &'static [u8];
        type Error = Infallible;

        fn poll_data(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            // Copy the `'static` slice out so splitting it does not borrow `self`.
            let chunks = self.0;
            let (&head, rest) = match chunks.split_first() {
                Some(parts) => parts,
                None => return Poll::Ready(None),
            };
            self.0 = rest;

            Poll::Ready(Some(Ok(head)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(Some(HeaderMap::new())))
        }
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
        let mut body = Limited::new(Chunky(DATA), LIMIT);

        assert_eq!(body.size_hint().upper(), Some(8));

        // The first chunk uses up the whole budget.
        let chunk = body.data().await.unwrap().unwrap();
        assert_eq!(chunk, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(0));

        let error = body.data().await.unwrap().unwrap_err();
        assert!(error.is::<LengthLimitError>());
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
        let mut body = Limited::new(Chunky(DATA), LIMIT);

        assert_eq!(body.size_hint().upper(), Some(8));

        let error = body.data().await.unwrap().unwrap_err();
        assert!(error.is::<LengthLimitError>());
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let mut body = Limited::new(Chunky(DATA), LIMIT);

        assert_eq!(body.size_hint().upper(), Some(8));

        let first = body.data().await.unwrap().unwrap();
        assert_eq!(first, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(4));

        let second = body.data().await.unwrap().unwrap();
        assert_eq!(second, DATA[1]);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let mut body = Limited::new(Chunky(DATA), LIMIT);

        let trailers = body.trailers().await.unwrap();
        assert_eq!(trailers, Some(HeaderMap::new()));
    }

    // Error type that records which poll method produced it.
    #[derive(Debug)]
    enum ErrorBodyError {
        Data,
        Trailers,
    }

    impl fmt::Display for ErrorBodyError {
        // The tests never inspect the message, so nothing is written.
        fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    // Test body whose data and trailers polls both fail immediately.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError::Data)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Err(ErrorBodyError::Trailers))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let mut body = Limited::new(ErrorBody, LIMIT);

        let error = body.data().await.unwrap().unwrap_err();
        let inner = error.downcast_ref::<ErrorBodyError>();
        assert!(matches!(inner, Some(ErrorBodyError::Data)));
    }

    #[tokio::test]
    async fn trailers_for_body_returning_error_propagates_error() {
        let mut body = Limited::new(ErrorBody, LIMIT);

        let error = body.trailers().await.unwrap_err();
        let inner = error.downcast_ref::<ErrorBodyError>();
        assert!(matches!(inner, Some(ErrorBodyError::Trailers)));
    }
}
300 | |