1use std::fmt;
2use std::io::IoSlice;
3
4use bytes::buf::{Chain, Take};
5use bytes::Buf;
6use tracing::trace;
7
8use super::io::WriteBuf;
9
/// A `'static` byte slice, used for the fixed framing bytes (`\r\n`,
/// `0\r\n\r\n`, ...) that surround chunked body data.
type StaticBuf = &'static [u8];
11
12/// Encoders to handle different Transfer-Encodings.
13#[derive(Debug, Clone, PartialEq)]
14pub(crate) struct Encoder {
15 kind: Kind,
16 is_last: bool,
17}
18
19#[derive(Debug)]
20pub(crate) struct EncodedBuf<B> {
21 kind: BufKind<B>,
22}
23
/// Error returned by `Encoder::end` when a length-delimited body is finished
/// early; the wrapped value is how many bytes were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
26
/// How an [`Encoder`] delimits the message body.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// An Encoder for when Transfer-Encoding includes `chunked`.
    Chunked,
    /// An Encoder for when Content-Length is set.
    ///
    /// The value is the number of body bytes still to be written. Enforces
    /// that the body is not longer than the Content-Length header.
    Length(u64),
    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
    ///
    /// This is mostly only used with HTTP/1.0 without a length. This kind
    /// requires the connection to be closed when the body is finished.
    #[cfg(feature = "server")]
    CloseDelimited,
}
42
43#[derive(Debug)]
44enum BufKind<B> {
45 Exact(B),
46 Limited(Take<B>),
47 Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
48 ChunkedEnd(StaticBuf),
49}
50
51impl Encoder {
52 fn new(kind: Kind) -> Encoder {
53 Encoder {
54 kind,
55 is_last: false,
56 }
57 }
58 pub(crate) fn chunked() -> Encoder {
59 Encoder::new(Kind::Chunked)
60 }
61
62 pub(crate) fn length(len: u64) -> Encoder {
63 Encoder::new(Kind::Length(len))
64 }
65
66 #[cfg(feature = "server")]
67 pub(crate) fn close_delimited() -> Encoder {
68 Encoder::new(Kind::CloseDelimited)
69 }
70
71 pub(crate) fn is_eof(&self) -> bool {
72 matches!(self.kind, Kind::Length(0))
73 }
74
75 #[cfg(feature = "server")]
76 pub(crate) fn set_last(mut self, is_last: bool) -> Self {
77 self.is_last = is_last;
78 self
79 }
80
81 pub(crate) fn is_last(&self) -> bool {
82 self.is_last
83 }
84
85 pub(crate) fn is_close_delimited(&self) -> bool {
86 match self.kind {
87 #[cfg(feature = "server")]
88 Kind::CloseDelimited => true,
89 _ => false,
90 }
91 }
92
93 pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
94 match self.kind {
95 Kind::Length(0) => Ok(None),
96 Kind::Chunked => Ok(Some(EncodedBuf {
97 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
98 })),
99 #[cfg(feature = "server")]
100 Kind::CloseDelimited => Ok(None),
101 Kind::Length(n) => Err(NotEof(n)),
102 }
103 }
104
105 pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
106 where
107 B: Buf,
108 {
109 let len = msg.remaining();
110 debug_assert!(len > 0, "encode() called with empty buf");
111
112 let kind = match self.kind {
113 Kind::Chunked => {
114 trace!("encoding chunked {}B", len);
115 let buf = ChunkSize::new(len)
116 .chain(msg)
117 .chain(b"\r\n" as &'static [u8]);
118 BufKind::Chunked(buf)
119 }
120 Kind::Length(ref mut remaining) => {
121 trace!("sized write, len = {}", len);
122 if len as u64 > *remaining {
123 let limit = *remaining as usize;
124 *remaining = 0;
125 BufKind::Limited(msg.take(limit))
126 } else {
127 *remaining -= len as u64;
128 BufKind::Exact(msg)
129 }
130 }
131 #[cfg(feature = "server")]
132 Kind::CloseDelimited => {
133 trace!("close delimited write {}B", len);
134 BufKind::Exact(msg)
135 }
136 };
137 EncodedBuf { kind }
138 }
139
140 pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
141 where
142 B: Buf,
143 {
144 let len = msg.remaining();
145 debug_assert!(len > 0, "encode() called with empty buf");
146
147 match self.kind {
148 Kind::Chunked => {
149 trace!("encoding chunked {}B", len);
150 let buf = ChunkSize::new(len)
151 .chain(msg)
152 .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
153 dst.buffer(buf);
154 !self.is_last
155 }
156 Kind::Length(remaining) => {
157 use std::cmp::Ordering;
158
159 trace!("sized write, len = {}", len);
160 match (len as u64).cmp(&remaining) {
161 Ordering::Equal => {
162 dst.buffer(msg);
163 !self.is_last
164 }
165 Ordering::Greater => {
166 dst.buffer(msg.take(remaining as usize));
167 !self.is_last
168 }
169 Ordering::Less => {
170 dst.buffer(msg);
171 false
172 }
173 }
174 }
175 #[cfg(feature = "server")]
176 Kind::CloseDelimited => {
177 trace!("close delimited write {}B", len);
178 dst.buffer(msg);
179 false
180 }
181 }
182 }
183
184 /// Encodes the full body, without verifying the remaining length matches.
185 ///
186 /// This is used in conjunction with HttpBody::__hyper_full_data(), which
187 /// means we can trust that the buf has the correct size (the buf itself
188 /// was checked to make the headers).
189 pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>)
190 where
191 B: Buf,
192 {
193 debug_assert!(msg.remaining() > 0, "encode() called with empty buf");
194 debug_assert!(
195 match self.kind {
196 Kind::Length(len) => len == msg.remaining() as u64,
197 _ => true,
198 },
199 "danger_full_buf length mismatches"
200 );
201
202 match self.kind {
203 Kind::Chunked => {
204 let len = msg.remaining();
205 trace!("encoding chunked {}B", len);
206 let buf = ChunkSize::new(len)
207 .chain(msg)
208 .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
209 dst.buffer(buf);
210 }
211 _ => {
212 dst.buffer(msg);
213 }
214 }
215 }
216}
217
218impl<B> Buf for EncodedBuf<B>
219where
220 B: Buf,
221{
222 #[inline]
223 fn remaining(&self) -> usize {
224 match self.kind {
225 BufKind::Exact(ref b) => b.remaining(),
226 BufKind::Limited(ref b) => b.remaining(),
227 BufKind::Chunked(ref b) => b.remaining(),
228 BufKind::ChunkedEnd(ref b) => b.remaining(),
229 }
230 }
231
232 #[inline]
233 fn chunk(&self) -> &[u8] {
234 match self.kind {
235 BufKind::Exact(ref b) => b.chunk(),
236 BufKind::Limited(ref b) => b.chunk(),
237 BufKind::Chunked(ref b) => b.chunk(),
238 BufKind::ChunkedEnd(ref b) => b.chunk(),
239 }
240 }
241
242 #[inline]
243 fn advance(&mut self, cnt: usize) {
244 match self.kind {
245 BufKind::Exact(ref mut b) => b.advance(cnt),
246 BufKind::Limited(ref mut b) => b.advance(cnt),
247 BufKind::Chunked(ref mut b) => b.advance(cnt),
248 BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
249 }
250 }
251
252 #[inline]
253 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
254 match self.kind {
255 BufKind::Exact(ref b) => b.chunks_vectored(dst),
256 BufKind::Limited(ref b) => b.chunks_vectored(dst),
257 BufKind::Chunked(ref b) => b.chunks_vectored(dst),
258 BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
259 }
260 }
261}
262
/// Number of bytes in a `usize` on the target platform.
///
/// Derived from `size_of` so the chunk-size buffer is sized correctly on every
/// pointer width, not only on 32- and 64-bit targets.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to print any `usize`
/// (each byte becomes 2 hex digits).
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
271
272#[derive(Clone, Copy)]
273struct ChunkSize {
274 bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
275 pos: u8,
276 len: u8,
277}
278
279impl ChunkSize {
280 fn new(len: usize) -> ChunkSize {
281 use std::fmt::Write;
282 let mut size: ChunkSize = ChunkSize {
283 bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
284 pos: 0,
285 len: 0,
286 };
287 write!(&mut size, "{:X}\r\n", len).expect(msg:"CHUNK_SIZE_MAX_BYTES should fit any usize");
288 size
289 }
290}
291
292impl Buf for ChunkSize {
293 #[inline]
294 fn remaining(&self) -> usize {
295 (self.len - self.pos).into()
296 }
297
298 #[inline]
299 fn chunk(&self) -> &[u8] {
300 &self.bytes[self.pos.into()..self.len.into()]
301 }
302
303 #[inline]
304 fn advance(&mut self, cnt: usize) {
305 assert!(cnt <= self.remaining());
306 self.pos += cnt as u8; // just asserted cnt fits in u8
307 }
308}
309
310impl fmt::Debug for ChunkSize {
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 f&mut DebugStruct<'_, '_>.debug_struct("ChunkSize")
313 .field("bytes", &&self.bytes[..self.len.into()])
314 .field(name:"pos", &self.pos)
315 .finish()
316 }
317}
318
319impl fmt::Write for ChunkSize {
320 fn write_str(&mut self, num: &str) -> fmt::Result {
321 use std::io::Write;
322 (&mut self.bytes[self.len.into()..])
323 .write_all(num.as_bytes())
324 .expect(msg:"&mut [u8].write() cannot error");
325 self.len += num.len() as u8; // safe because bytes is never bigger than 256
326 Ok(())
327 }
328}
329
330impl<B: Buf> From<B> for EncodedBuf<B> {
331 fn from(buf: B) -> Self {
332 EncodedBuf {
333 kind: BufKind::Exact(buf),
334 }
335 }
336}
337
338impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
339 fn from(buf: Take<B>) -> Self {
340 EncodedBuf {
341 kind: BufKind::Limited(buf),
342 }
343 }
344}
345
346impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
347 fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
348 EncodedBuf {
349 kind: BufKind::Chunked(buf),
350 }
351 }
352}
353
354impl fmt::Display for NotEof {
355 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356 write!(f, "early end, expected {} more bytes", self.0)
357 }
358}
359
360impl std::error::Error for NotEof {}
361
#[cfg(test)]
mod tests {
    use bytes::BufMut;

    use super::super::io::Cursor;
    use super::{ChunkSize, Encoder, CHUNK_SIZE_MAX_BYTES};

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // `Encoder::close_delimited` only exists with the `server` feature, so this
    // test must be gated the same way or the test build fails without it.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    // The largest possible chunk length must still fit the inline buffer.
    #[test]
    fn chunk_size_fits_usize_max() {
        use bytes::Buf;

        let size = ChunkSize::new(usize::MAX);
        assert_eq!(size.remaining(), CHUNK_SIZE_MAX_BYTES + 2);
        assert_eq!(size.chunk(), format!("{:X}\r\n", usize::MAX).as_bytes());
    }
}
440