1use crate::{Body, SizeHint};
2use bytes::Buf;
3use http::HeaderMap;
4use pin_project_lite::pin_project;
5use std::error::Error;
6use std::fmt;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
pin_project! {
    /// A length limited body.
    ///
    /// This body will return an error if more than the configured number
    /// of bytes are returned on polling the wrapped body.
    ///
    /// Each data chunk is charged by its `Buf::remaining()` length. A chunk
    /// that would exceed the remaining budget is not returned; a boxed
    /// `LengthLimitError` is yielded in its place. Errors from the wrapped
    /// body are boxed and forwarded unchanged.
    #[derive(Clone, Copy, Debug)]
    pub struct Limited<B> {
        // Number of bytes still allowed through before polling yields
        // `LengthLimitError`; reset to 0 once the limit is exceeded.
        remaining: usize,
        // The wrapped body; structurally pinned so it can be polled through
        // `Pin<&mut Limited<B>>`.
        #[pin]
        inner: B,
    }
}
22
23impl<B> Limited<B> {
24 /// Create a new `Limited`.
25 pub fn new(inner: B, limit: usize) -> Self {
26 Self {
27 remaining: limit,
28 inner,
29 }
30 }
31}
32
33impl<B> Body for Limited<B>
34where
35 B: Body,
36 B::Error: Into<Box<dyn Error + Send + Sync>>,
37{
38 type Data = B::Data;
39 type Error = Box<dyn Error + Send + Sync>;
40
41 fn poll_data(
42 self: Pin<&mut Self>,
43 cx: &mut Context<'_>,
44 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
45 let this = self.project();
46 let res = match this.inner.poll_data(cx) {
47 Poll::Pending => return Poll::Pending,
48 Poll::Ready(None) => None,
49 Poll::Ready(Some(Ok(data))) => {
50 if data.remaining() > *this.remaining {
51 *this.remaining = 0;
52 Some(Err(LengthLimitError.into()))
53 } else {
54 *this.remaining -= data.remaining();
55 Some(Ok(data))
56 }
57 }
58 Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
59 };
60
61 Poll::Ready(res)
62 }
63
64 fn poll_trailers(
65 self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
68 let this = self.project();
69 let res = match this.inner.poll_trailers(cx) {
70 Poll::Pending => return Poll::Pending,
71 Poll::Ready(Ok(data)) => Ok(data),
72 Poll::Ready(Err(err)) => Err(err.into()),
73 };
74
75 Poll::Ready(res)
76 }
77
78 fn is_end_stream(&self) -> bool {
79 self.inner.is_end_stream()
80 }
81
82 fn size_hint(&self) -> SizeHint {
83 use std::convert::TryFrom;
84 match u64::try_from(self.remaining) {
85 Ok(n) => {
86 let mut hint = self.inner.size_hint();
87 if hint.lower() >= n {
88 hint.set_exact(n)
89 } else if let Some(max) = hint.upper() {
90 hint.set_upper(n.min(max))
91 } else {
92 hint.set_upper(n)
93 }
94 hint
95 }
96 Err(_) => self.inner.size_hint(),
97 }
98 }
99}
100
/// An error returned when body length exceeds the configured limit.
///
/// `Limited` yields this error, boxed as `Box<dyn Error + Send + Sync>`, from
/// `poll_data` when a data chunk would push the number of bytes read past the
/// configured limit. Callers can detect it with `downcast_ref`.
#[derive(Debug)]
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("length limit exceeded")
    }
}

impl Error for LengthLimitError {}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Full;
    use bytes::Bytes;
    use std::convert::Infallible;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(7);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(matches!(body.data().await, None));
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        // The inner body promises more bytes than the limit allows, so the
        // hint is clamped to exactly the limit: both bounds, not just upper.
        let mut hint = SizeHint::new();
        hint.set_exact(8);
        assert_eq!(body.size_hint().upper(), hint.upper());
        assert_eq!(body.size_hint().lower(), hint.lower());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_empty_body_with_zero_limit_returns_none() {
        let inner = Full::new(Bytes::new());
        let body = &mut Limited::new(inner, 0);

        let hint = body.size_hint();
        assert_eq!(hint.lower(), 0);
        assert_eq!(hint.upper(), Some(0));

        assert!(matches!(body.data().await, None));
    }

    #[tokio::test]
    async fn read_for_non_empty_body_with_zero_limit_returns_error() {
        let inner = Full::new(Bytes::from_static(b"x"));
        let body = &mut Limited::new(inner, 0);

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    // A test body that yields each slice of the wrapped array as a separate
    // data chunk, then reports empty trailers.
    struct Chunky(&'static [&'static [u8]]);

    impl Body for Chunky {
        type Data = &'static [u8];
        type Error = Infallible;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            let mut this = self;
            match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
                Some((data, new_tail)) => {
                    this.0 = new_tail;

                    Poll::Ready(Some(data))
                }
                None => Poll::Ready(None),
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(Some(HeaderMap::new())))
        }
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);

        let mut hint = SizeHint::new();
        hint.set_upper(8);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[0]);
        hint.set_upper(4);
        assert_eq!(body.size_hint().upper(), hint.upper());

        let data = body.data().await.unwrap().unwrap();
        assert_eq!(data, DATA[1]);
        hint.set_upper(0);
        assert_eq!(body.size_hint().upper(), hint.upper());

        assert!(matches!(body.data().await, None));
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let inner = Chunky(DATA);
        let body = &mut Limited::new(inner, 8);
        let trailers = body.trailers().await.unwrap();
        assert_eq!(trailers, Some(HeaderMap::new()))
    }

    // Distinguishes which poll method of `ErrorBody` produced the error.
    #[derive(Debug)]
    enum ErrorBodyError {
        Data,
        Trailers,
    }

    impl fmt::Display for ErrorBodyError {
        fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    // A test body whose data and trailers polls always fail.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError::Data)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Err(ErrorBodyError::Trailers))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.data().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
    }

    #[tokio::test]
    async fn trailers_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.trailers().await.unwrap_err();
        assert!(matches!(
            error.downcast_ref(),
            Some(ErrorBodyError::Trailers)
        ));
    }
}
300