1// Copyright 2015 The tiny-http Contributors
2// Copyright 2015 The rust-chunked-transfer Contributors
3// Forked into ureq, 2022, from https://github.com/frewsxcv/rust-chunked-transfer
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// https://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17use std::error::Error;
18use std::fmt;
19use std::io::Error as IoError;
20use std::io::ErrorKind;
21use std::io::Read;
22use std::io::Result as IoResult;
23
/// Decodes a body sent with HTTP chunked transfer encoding.
///
/// The decoder wraps a source of chunk-encoded bytes and, through its
/// `Read` implementation, hands back only the chunk payloads: the
/// hexadecimal size lines and the CRLF separators are consumed internally.
///
/// # Example
///
/// ```no_compile
/// use chunked_transfer::Decoder;
/// use std::io::Read;
///
/// let encoded = b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n";
/// let mut decoded = String::new();
///
/// let mut decoder = Decoder::new(encoded as &[u8]);
/// decoder.read_to_string(&mut decoded).unwrap();
///
/// assert_eq!(decoded, "hello world!!!");
/// ```
pub struct Decoder<R> {
    // The underlying reader that supplies the chunk-encoded bytes.
    source: R,

    // Number of payload bytes still to be read from the current chunk.
    // `None` means we are between chunks and the next bytes from `source`
    // are expected to be a chunk-size line.
    remaining_chunks_size: Option<usize>,
}
48
49impl<R> Decoder<R>
50where
51 R: Read,
52{
53 pub fn new(source: R) -> Decoder<R> {
54 Decoder {
55 source,
56 remaining_chunks_size: None,
57 }
58 }
59
60 /// Returns the remaining bytes left in the chunk being read.
61 pub fn remaining_chunks_size(&self) -> Option<usize> {
62 self.remaining_chunks_size
63 }
64
65 /// Unwraps the Decoder into its inner `Read` source.
66 pub fn into_inner(self) -> R {
67 self.source
68 }
69
70 fn read_chunk_size(&mut self) -> IoResult<usize> {
71 let mut chunk_size_bytes = Vec::new();
72 let mut has_ext = false;
73
74 loop {
75 let byte = match self.source.by_ref().bytes().next() {
76 Some(b) => b?,
77 None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
78 };
79
80 if byte == b'\r' {
81 break;
82 }
83
84 if byte == b';' {
85 has_ext = true;
86 break;
87 }
88
89 chunk_size_bytes.push(byte);
90 }
91
92 // Ignore extensions for now
93 if has_ext {
94 loop {
95 let byte = match self.source.by_ref().bytes().next() {
96 Some(b) => b?,
97 None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
98 };
99 if byte == b'\r' {
100 break;
101 }
102 }
103 }
104
105 self.read_line_feed()?;
106
107 let chunk_size = String::from_utf8(chunk_size_bytes)
108 .ok()
109 .and_then(|c| usize::from_str_radix(c.trim(), 16).ok())
110 .ok_or_else(|| IoError::new(ErrorKind::InvalidInput, DecoderError))?;
111
112 Ok(chunk_size)
113 }
114
115 fn read_carriage_return(&mut self) -> IoResult<()> {
116 match self.source.by_ref().bytes().next() {
117 Some(Ok(b'\r')) => Ok(()),
118 _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
119 }
120 }
121
122 fn read_line_feed(&mut self) -> IoResult<()> {
123 match self.source.by_ref().bytes().next() {
124 Some(Ok(b'\n')) => Ok(()),
125 _ => Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
126 }
127 }
128}
129
130impl<R> Read for Decoder<R>
131where
132 R: Read,
133{
134 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
135 let remaining_chunks_size = match self.remaining_chunks_size {
136 Some(c) => c,
137 None => {
138 // first possibility: we are not in a chunk, so we'll attempt to determine
139 // the chunks size
140 let chunk_size = self.read_chunk_size()?;
141
142 // if the chunk size is 0, we are at EOF
143 if chunk_size == 0 {
144 self.read_carriage_return()?;
145 self.read_line_feed()?;
146 return Ok(0);
147 }
148
149 chunk_size
150 }
151 };
152
153 // second possibility: we continue reading from a chunk
154 if buf.len() < remaining_chunks_size {
155 let read = self.source.read(buf)?;
156 self.remaining_chunks_size = Some(remaining_chunks_size - read);
157 return Ok(read);
158 }
159
160 // third possibility: the read request goes further than the current chunk
161 // we simply read until the end of the chunk and return
162 assert!(buf.len() >= remaining_chunks_size);
163
164 let buf = &mut buf[..remaining_chunks_size];
165 let read = self.source.read(buf)?;
166
167 self.remaining_chunks_size = if read == remaining_chunks_size {
168 self.read_carriage_return()?;
169 self.read_line_feed()?;
170 None
171 } else {
172 Some(remaining_chunks_size - read)
173 };
174
175 Ok(read)
176 }
177}
178
/// Marker error attached to the `io::Error`s produced when the chunked
/// stream is malformed.
#[derive(Clone, Copy, Debug)]
struct DecoderError;

impl fmt::Display for DecoderError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str("Error while decoding chunks")
    }
}

impl Error for DecoderError {
    /// Legacy textual description; identical to the `Display` output.
    fn description(&self) -> &str {
        "Error while decoding chunks"
    }
}
193
// Unit tests for the chunked decoder: chunk-size line parsing plus
// end-to-end decoding of well-formed and malformed bodies.
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// Exercises chunk-size line parsing on valid sizes, invalid sizes and
    /// chunk extensions.
    ///
    /// This unit test is taken from Hyper
    /// https://github.com/hyperium/hyper
    /// Copyright (c) 2014 Sean McArthur
    #[test]
    fn test_read_chunk_size() {
        // Asserts that `s` parses as a chunk-size line announcing `expected`.
        fn read(s: &str, expected: usize) {
            let mut decoded = Decoder::new(s.as_bytes());
            let actual = decoded.read_chunk_size().unwrap();
            assert_eq!(expected, actual);
        }

        // Asserts that parsing `s` fails with `ErrorKind::InvalidInput`.
        fn read_err(s: &str) {
            let mut decoded = Decoder::new(s.as_bytes());
            let err_kind = decoded.read_chunk_size().unwrap_err().kind();
            assert_eq!(err_kind, io::ErrorKind::InvalidInput);
        }

        // Plain hexadecimal sizes, including leading zeros, mixed case and
        // trailing whitespace
        read("1\r\n", 1);
        read("01\r\n", 1);
        read("0\r\n", 0);
        read("00\r\n", 0);
        read("A\r\n", 10);
        read("a\r\n", 10);
        read("Ff\r\n", 255);
        read("Ff   \r\n", 255);
        // Missing LF or CRLF
        read_err("F\rF");
        read_err("F");
        // Invalid hex digit
        read_err("X\r\n");
        read_err("1X\r\n");
        read_err("-\r\n");
        read_err("-1\r\n");
        // Acceptable (if not fully valid) extensions do not influence the size
        read("1;extension\r\n", 1);
        read("a;ext name=value\r\n", 10);
        read("1;extension;extension2\r\n", 1);
        read("1;;;  ;\r\n", 1);
        read("2; extension...\r\n", 2);
        read("3   ; extension=123\r\n", 3);
        read("3   ;\r\n", 3);
        read("3   ;   \r\n", 3);
        // Invalid extensions cause an error
        read_err("1 invalid extension\r\n");
        read_err("1 A\r\n");
        read_err("1;no CRLF");
    }

    /// A well-formed two-chunk body decodes to the concatenated payloads.
    #[test]
    fn test_valid_chunk_decode() {
        let source = io::Cursor::new(
            "3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        decoded.read_to_string(&mut string).unwrap();

        assert_eq!(string, "hello world!!!");
    }

    /// A body consisting only of the terminating chunk decodes to nothing.
    #[test]
    fn test_decode_zero_length() {
        let mut decoder = Decoder::new(b"0\r\n\r\n" as &[u8]);

        let mut decoded = String::new();
        decoder.read_to_string(&mut decoded).unwrap();

        assert_eq!(decoded, "");
    }

    /// A chunk size that is not hexadecimal makes decoding fail.
    #[test]
    fn test_decode_invalid_chunk_length() {
        let mut decoder = Decoder::new(b"m\r\n\r\n" as &[u8]);

        let mut decoded = String::new();
        assert!(decoder.read_to_string(&mut decoded).is_err());
    }

    /// The announced size (2) is shorter than the payload ("hel"), so the
    /// byte found where the chunk's CRLF should be is not `\r`.
    #[test]
    fn invalid_input1() {
        let source = io::Cursor::new(
            "2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        assert!(decoded.read_to_string(&mut string).is_err());
    }

    /// The first chunk-size line is terminated by a bare `\r` without `\n`.
    #[test]
    fn invalid_input2() {
        let source = io::Cursor::new(
            "3\rhel\r\nb\r\nlo world!!!\r\n0\r\n"
                .to_string()
                .into_bytes(),
        );
        let mut decoded = Decoder::new(source);

        let mut string = String::new();
        assert!(decoded.read_to_string(&mut string).is_err());
    }
}
307