1 | use std::fmt; |
2 | use std::io::IoSlice; |
3 | |
4 | use bytes::buf::{Chain, Take}; |
5 | use bytes::Buf; |
6 | use tracing::trace; |
7 | |
8 | use super::io::WriteBuf; |
9 | |
10 | type StaticBuf = &'static [u8]; |
11 | |
12 | /// Encoders to handle different Transfer-Encodings. |
13 | #[derive (Debug, Clone, PartialEq)] |
14 | pub(crate) struct Encoder { |
15 | kind: Kind, |
16 | is_last: bool, |
17 | } |
18 | |
19 | #[derive (Debug)] |
20 | pub(crate) struct EncodedBuf<B> { |
21 | kind: BufKind<B>, |
22 | } |
23 | |
24 | #[derive (Debug)] |
25 | pub(crate) struct NotEof(u64); |
26 | |
/// The framing strategy an [`Encoder`] applies to the body.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// An Encoder for when Transfer-Encoding includes `chunked`.
    Chunked,
    /// An Encoder for when Content-Length is set.
    ///
    /// Enforces that the body is not longer than the Content-Length header.
    /// The payload is the number of body bytes that may still be written;
    /// it is decremented as data is encoded.
    Length(u64),
    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
    ///
    /// This is mostly only used with HTTP/1.0 without a length. This kind
    /// requires the connection to be closed when the body is finished.
    #[cfg(feature = "server")]
    CloseDelimited,
}
42 | |
43 | #[derive (Debug)] |
44 | enum BufKind<B> { |
45 | Exact(B), |
46 | Limited(Take<B>), |
47 | Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>), |
48 | ChunkedEnd(StaticBuf), |
49 | } |
50 | |
51 | impl Encoder { |
52 | fn new(kind: Kind) -> Encoder { |
53 | Encoder { |
54 | kind, |
55 | is_last: false, |
56 | } |
57 | } |
58 | pub(crate) fn chunked() -> Encoder { |
59 | Encoder::new(Kind::Chunked) |
60 | } |
61 | |
62 | pub(crate) fn length(len: u64) -> Encoder { |
63 | Encoder::new(Kind::Length(len)) |
64 | } |
65 | |
66 | #[cfg (feature = "server" )] |
67 | pub(crate) fn close_delimited() -> Encoder { |
68 | Encoder::new(Kind::CloseDelimited) |
69 | } |
70 | |
71 | pub(crate) fn is_eof(&self) -> bool { |
72 | matches!(self.kind, Kind::Length(0)) |
73 | } |
74 | |
75 | #[cfg (feature = "server" )] |
76 | pub(crate) fn set_last(mut self, is_last: bool) -> Self { |
77 | self.is_last = is_last; |
78 | self |
79 | } |
80 | |
81 | pub(crate) fn is_last(&self) -> bool { |
82 | self.is_last |
83 | } |
84 | |
85 | pub(crate) fn is_close_delimited(&self) -> bool { |
86 | match self.kind { |
87 | #[cfg (feature = "server" )] |
88 | Kind::CloseDelimited => true, |
89 | _ => false, |
90 | } |
91 | } |
92 | |
93 | pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> { |
94 | match self.kind { |
95 | Kind::Length(0) => Ok(None), |
96 | Kind::Chunked => Ok(Some(EncodedBuf { |
97 | kind: BufKind::ChunkedEnd(b"0 \r\n\r\n" ), |
98 | })), |
99 | #[cfg (feature = "server" )] |
100 | Kind::CloseDelimited => Ok(None), |
101 | Kind::Length(n) => Err(NotEof(n)), |
102 | } |
103 | } |
104 | |
105 | pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B> |
106 | where |
107 | B: Buf, |
108 | { |
109 | let len = msg.remaining(); |
110 | debug_assert!(len > 0, "encode() called with empty buf" ); |
111 | |
112 | let kind = match self.kind { |
113 | Kind::Chunked => { |
114 | trace!("encoding chunked {}B" , len); |
115 | let buf = ChunkSize::new(len) |
116 | .chain(msg) |
117 | .chain(b" \r\n" as &'static [u8]); |
118 | BufKind::Chunked(buf) |
119 | } |
120 | Kind::Length(ref mut remaining) => { |
121 | trace!("sized write, len = {}" , len); |
122 | if len as u64 > *remaining { |
123 | let limit = *remaining as usize; |
124 | *remaining = 0; |
125 | BufKind::Limited(msg.take(limit)) |
126 | } else { |
127 | *remaining -= len as u64; |
128 | BufKind::Exact(msg) |
129 | } |
130 | } |
131 | #[cfg (feature = "server" )] |
132 | Kind::CloseDelimited => { |
133 | trace!("close delimited write {}B" , len); |
134 | BufKind::Exact(msg) |
135 | } |
136 | }; |
137 | EncodedBuf { kind } |
138 | } |
139 | |
140 | pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool |
141 | where |
142 | B: Buf, |
143 | { |
144 | let len = msg.remaining(); |
145 | debug_assert!(len > 0, "encode() called with empty buf" ); |
146 | |
147 | match self.kind { |
148 | Kind::Chunked => { |
149 | trace!("encoding chunked {}B" , len); |
150 | let buf = ChunkSize::new(len) |
151 | .chain(msg) |
152 | .chain(b" \r\n0 \r\n\r\n" as &'static [u8]); |
153 | dst.buffer(buf); |
154 | !self.is_last |
155 | } |
156 | Kind::Length(remaining) => { |
157 | use std::cmp::Ordering; |
158 | |
159 | trace!("sized write, len = {}" , len); |
160 | match (len as u64).cmp(&remaining) { |
161 | Ordering::Equal => { |
162 | dst.buffer(msg); |
163 | !self.is_last |
164 | } |
165 | Ordering::Greater => { |
166 | dst.buffer(msg.take(remaining as usize)); |
167 | !self.is_last |
168 | } |
169 | Ordering::Less => { |
170 | dst.buffer(msg); |
171 | false |
172 | } |
173 | } |
174 | } |
175 | #[cfg (feature = "server" )] |
176 | Kind::CloseDelimited => { |
177 | trace!("close delimited write {}B" , len); |
178 | dst.buffer(msg); |
179 | false |
180 | } |
181 | } |
182 | } |
183 | |
184 | /// Encodes the full body, without verifying the remaining length matches. |
185 | /// |
186 | /// This is used in conjunction with HttpBody::__hyper_full_data(), which |
187 | /// means we can trust that the buf has the correct size (the buf itself |
188 | /// was checked to make the headers). |
189 | pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) |
190 | where |
191 | B: Buf, |
192 | { |
193 | debug_assert!(msg.remaining() > 0, "encode() called with empty buf" ); |
194 | debug_assert!( |
195 | match self.kind { |
196 | Kind::Length(len) => len == msg.remaining() as u64, |
197 | _ => true, |
198 | }, |
199 | "danger_full_buf length mismatches" |
200 | ); |
201 | |
202 | match self.kind { |
203 | Kind::Chunked => { |
204 | let len = msg.remaining(); |
205 | trace!("encoding chunked {}B" , len); |
206 | let buf = ChunkSize::new(len) |
207 | .chain(msg) |
208 | .chain(b" \r\n0 \r\n\r\n" as &'static [u8]); |
209 | dst.buffer(buf); |
210 | } |
211 | _ => { |
212 | dst.buffer(msg); |
213 | } |
214 | } |
215 | } |
216 | } |
217 | |
218 | impl<B> Buf for EncodedBuf<B> |
219 | where |
220 | B: Buf, |
221 | { |
222 | #[inline ] |
223 | fn remaining(&self) -> usize { |
224 | match self.kind { |
225 | BufKind::Exact(ref b) => b.remaining(), |
226 | BufKind::Limited(ref b) => b.remaining(), |
227 | BufKind::Chunked(ref b) => b.remaining(), |
228 | BufKind::ChunkedEnd(ref b) => b.remaining(), |
229 | } |
230 | } |
231 | |
232 | #[inline ] |
233 | fn chunk(&self) -> &[u8] { |
234 | match self.kind { |
235 | BufKind::Exact(ref b) => b.chunk(), |
236 | BufKind::Limited(ref b) => b.chunk(), |
237 | BufKind::Chunked(ref b) => b.chunk(), |
238 | BufKind::ChunkedEnd(ref b) => b.chunk(), |
239 | } |
240 | } |
241 | |
242 | #[inline ] |
243 | fn advance(&mut self, cnt: usize) { |
244 | match self.kind { |
245 | BufKind::Exact(ref mut b) => b.advance(cnt), |
246 | BufKind::Limited(ref mut b) => b.advance(cnt), |
247 | BufKind::Chunked(ref mut b) => b.advance(cnt), |
248 | BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), |
249 | } |
250 | } |
251 | |
252 | #[inline ] |
253 | fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { |
254 | match self.kind { |
255 | BufKind::Exact(ref b) => b.chunks_vectored(dst), |
256 | BufKind::Limited(ref b) => b.chunks_vectored(dst), |
257 | BufKind::Chunked(ref b) => b.chunks_vectored(dst), |
258 | BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), |
259 | } |
260 | } |
261 | } |
262 | |
/// Width of a `usize` in bytes on the compilation target.
///
/// Derived from the type itself rather than per-`target_pointer_width`
/// constants, so every pointer width (not just 32- and 64-bit) compiles.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

// Each byte of a `usize` becomes at most 2 hex digits in the chunk-size line.
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
271 | |
272 | #[derive (Clone, Copy)] |
273 | struct ChunkSize { |
274 | bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2], |
275 | pos: u8, |
276 | len: u8, |
277 | } |
278 | |
279 | impl ChunkSize { |
280 | fn new(len: usize) -> ChunkSize { |
281 | use std::fmt::Write; |
282 | let mut size: ChunkSize = ChunkSize { |
283 | bytes: [0; CHUNK_SIZE_MAX_BYTES + 2], |
284 | pos: 0, |
285 | len: 0, |
286 | }; |
287 | write!(&mut size, " {:X}\r\n" , len).expect(msg:"CHUNK_SIZE_MAX_BYTES should fit any usize" ); |
288 | size |
289 | } |
290 | } |
291 | |
292 | impl Buf for ChunkSize { |
293 | #[inline ] |
294 | fn remaining(&self) -> usize { |
295 | (self.len - self.pos).into() |
296 | } |
297 | |
298 | #[inline ] |
299 | fn chunk(&self) -> &[u8] { |
300 | &self.bytes[self.pos.into()..self.len.into()] |
301 | } |
302 | |
303 | #[inline ] |
304 | fn advance(&mut self, cnt: usize) { |
305 | assert!(cnt <= self.remaining()); |
306 | self.pos += cnt as u8; // just asserted cnt fits in u8 |
307 | } |
308 | } |
309 | |
310 | impl fmt::Debug for ChunkSize { |
311 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
312 | f&mut DebugStruct<'_, '_>.debug_struct("ChunkSize" ) |
313 | .field("bytes" , &&self.bytes[..self.len.into()]) |
314 | .field(name:"pos" , &self.pos) |
315 | .finish() |
316 | } |
317 | } |
318 | |
319 | impl fmt::Write for ChunkSize { |
320 | fn write_str(&mut self, num: &str) -> fmt::Result { |
321 | use std::io::Write; |
322 | (&mut self.bytes[self.len.into()..]) |
323 | .write_all(num.as_bytes()) |
324 | .expect(msg:"&mut [u8].write() cannot error" ); |
325 | self.len += num.len() as u8; // safe because bytes is never bigger than 256 |
326 | Ok(()) |
327 | } |
328 | } |
329 | |
330 | impl<B: Buf> From<B> for EncodedBuf<B> { |
331 | fn from(buf: B) -> Self { |
332 | EncodedBuf { |
333 | kind: BufKind::Exact(buf), |
334 | } |
335 | } |
336 | } |
337 | |
338 | impl<B: Buf> From<Take<B>> for EncodedBuf<B> { |
339 | fn from(buf: Take<B>) -> Self { |
340 | EncodedBuf { |
341 | kind: BufKind::Limited(buf), |
342 | } |
343 | } |
344 | } |
345 | |
346 | impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> { |
347 | fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self { |
348 | EncodedBuf { |
349 | kind: BufKind::Chunked(buf), |
350 | } |
351 | } |
352 | } |
353 | |
354 | impl fmt::Display for NotEof { |
355 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
356 | write!(f, "early end, expected {} more bytes" , self.0) |
357 | } |
358 | } |
359 | |
360 | impl std::error::Error for NotEof {} |
361 | |
#[cfg(test)]
mod tests {
    use bytes::BufMut;

    use super::super::io::Cursor;
    use super::Encoder;

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        // Only the first byte of "baz" fits in the remaining Content-Length.
        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // `Encoder::close_delimited` only exists with the `server` feature, so the
    // test must be gated the same way or the test build fails without it.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    // The chunk-size buffer must hold the hex form of the largest `usize`.
    #[test]
    fn chunk_size_fits_usize_max() {
        use bytes::Buf;

        let size = super::ChunkSize::new(usize::MAX);
        let expected = format!("{:X}\r\n", usize::MAX);
        assert_eq!(size.chunk(), expected.as_bytes());
        assert_eq!(size.remaining(), expected.len());
    }
}
440 | |