1use super::{
2 io, ready, slice_from, AsyncRead, Buffer, Chunked, Context, End, ErrorKind, InvalidChunkSize,
3 PartialChunkSize, Pin, Ready, ReceivedBody, ReceivedBodyState, StateOutput, Status,
4};
5
impl<'conn, Transport> ReceivedBody<'conn, Transport>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    /// Polls the transport while in the `Chunked` state: reads raw bytes
    /// into `buf`, then strips chunked-transfer framing in place via
    /// [`chunk_decode`].
    ///
    /// `remaining` is the number of bytes still owed by the current chunk
    /// (including its trailing CRLF); `total` is the running count of
    /// decoded body bytes, used by `chunk_decode` to enforce `max_len`.
    #[inline]
    pub(super) fn handle_chunked(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
        remaining: u64,
        total: u64,
    ) -> StateOutput {
        // A pending read short-circuits out of this fn via `ready!`.
        let bytes = ready!(self.read_raw(cx, buf)?);

        // Decode only the bytes actually read; chunk_decode compacts the
        // payload to the front of `buf` and returns the follow-on state.
        Ready(chunk_decode(
            &mut self.buffer,
            remaining,
            total,
            &mut buf[..bytes],
            self.max_len,
        ))
    }

    /// Polls the transport while in the `PartialChunkSize` state: a
    /// chunk-size line was split across reads, so accumulate bytes into
    /// `self.buffer` until httparse can parse a complete size line.
    #[inline]
    pub(super) fn handle_partial(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
        total: u64,
    ) -> StateOutput {
        // The transport may have been taken elsewhere (e.g. upgraded);
        // report that as NotConnected rather than panicking.
        let transport = self
            .transport
            .as_deref_mut()
            .ok_or_else(|| io::Error::from(ErrorKind::NotConnected))?;
        let bytes = ready!(Pin::new(transport).poll_read(cx, buf))?;

        // EOF in the middle of a chunk-size line: the peer hung up
        // before terminating the chunked body.
        if bytes == 0 {
            return Ready(Err(io::Error::from(ErrorKind::ConnectionAborted)));
        }

        self.buffer.extend_from_slice(&buf[..bytes]);

        match httparse::parse_chunk_size(&self.buffer) {
            Ok(Status::Complete((framing_bytes, remaining))) => {
                // Drop the parsed size line (and its CRLF) from the
                // front; the rest is payload for subsequent polls.
                self.buffer.ignore_front(framing_bytes);
                Ready(Ok((
                    if remaining == 0 {
                        // Terminal zero-length chunk.
                        // NOTE(review): unlike chunk_decode's End arm,
                        // this does not skip the final CRLF that follows
                        // "0\r\n" — confirm a later state consumes it
                        // before the buffer feeds the next request.
                        End
                    } else {
                        Chunked {
                            // +2 accounts for the CRLF terminating the
                            // chunk payload.
                            remaining: remaining + 2,
                            total,
                        }
                    },
                    // No decoded payload bytes were produced this poll.
                    0,
                )))
            }

            // Still not enough bytes to finish the size line; remain in
            // this state and poll again.
            Ok(Status::Partial) => Ready(Ok((PartialChunkSize { total }, 0))),

            Err(InvalidChunkSize) => Ready(Err(io::Error::new(
                ErrorKind::InvalidData,
                "invalid chunk framing",
            ))),
        }
    }
}
73
/// Decodes chunked transfer-encoding framing in place.
///
/// `buf` begins with `remaining` bytes carried over from the current
/// chunk (count includes its terminating CRLF), followed by freshly read
/// transport bytes. Payload bytes are compacted to the front of `buf`;
/// bytes belonging to an incomplete chunk-size line — or to whatever
/// follows the terminal chunk — are stashed in `self_buffer`.
///
/// Returns the next [`ReceivedBodyState`] and the number of decoded
/// payload bytes now at the start of `buf`. Errors on an empty read,
/// malformed framing, chunk-size overflow, or a body exceeding `max_len`.
pub(super) fn chunk_decode(
    self_buffer: &mut Buffer,
    remaining: u64,
    mut total: u64,
    buf: &mut [u8],
    max_len: u64,
) -> io::Result<(ReceivedBodyState, usize)> {
    // A zero-byte read while a chunked body is outstanding means the
    // peer closed the connection mid-body.
    if buf.is_empty() {
        return Err(io::Error::from(ErrorKind::ConnectionAborted));
    }
    // Payload sub-ranges of `buf` (framing excluded); compacted at the end.
    let mut ranges_to_keep = vec![];
    // Cursors into the *conceptual* byte stream — they may run past
    // buf.len() when a chunk extends beyond this read.
    let mut chunk_start = 0u64;
    // End of the current chunk, INCLUDING its trailing CRLF.
    let mut chunk_end = remaining;
    let request_body_state = loop {
        // chunk_end > 2 means the current chunk still owes payload bytes
        // (the last 2 are the CRLF chunk terminator, never kept).
        if chunk_end > 2 {
            let keep_start = usize::try_from(chunk_start).unwrap_or(usize::MAX);
            // Clamp to the bytes actually present in this read.
            let keep_end = buf
                .len()
                .min(usize::try_from(chunk_end - 2).unwrap_or(usize::MAX));
            ranges_to_keep.push(keep_start..keep_end);
            let new_bytes = (keep_end - keep_start) as u64;
            total += new_bytes;
            // Enforce the configured body-length cap as we go.
            if total > max_len {
                return Err(io::Error::new(ErrorKind::Unsupported, "content too long"));
            }
        }
        chunk_start = chunk_end;

        // Step past the current chunk; if it extends beyond this read,
        // record how many of its bytes are still owed.
        let Some(buf_to_read) = slice_from(chunk_start, buf) else {
            break Chunked {
                remaining: (chunk_start - buf.len() as u64),
                total,
            };
        };

        // Chunk ended exactly at the read boundary: nothing left to
        // parse until the next poll.
        if buf_to_read.is_empty() {
            break Chunked {
                remaining: (chunk_start - buf.len() as u64),
                total,
            };
        }

        match httparse::parse_chunk_size(buf_to_read) {
            Ok(Status::Complete((framing_bytes, chunk_size))) => {
                chunk_start += framing_bytes as u64;
                // +2 covers the CRLF that will terminate this chunk's
                // payload; checked_add rejects absurd declared sizes.
                chunk_end = (2 + chunk_start)
                    .checked_add(chunk_size)
                    .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "chunk size too long"))?;

                if chunk_size == 0 {
                    // Terminal chunk: everything past its trailing CRLF
                    // belongs to the next request on this connection, so
                    // preserve it for the parser.
                    if let Some(buf) = slice_from(chunk_end, buf) {
                        self_buffer.extend_from_slice(buf);
                    }
                    break End;
                }
            }

            Ok(Status::Partial) => {
                // The chunk-size line is split across reads; keep the
                // fragment so the next poll can finish parsing it.
                self_buffer.extend_from_slice(buf_to_read);
                break PartialChunkSize { total };
            }

            Err(InvalidChunkSize) => {
                return Err(io::Error::new(ErrorKind::InvalidData, "invalid chunk size"));
            }
        }
    };

    // Compact the retained payload ranges to the front of `buf`.
    // Ranges are in ascending order, so copy_within never overwrites
    // bytes that are still pending.
    let mut bytes = 0;

    for range_to_keep in ranges_to_keep {
        let new_bytes = bytes + range_to_keep.end - range_to_keep.start;
        buf.copy_within(range_to_keep, bytes);
        bytes = new_bytes;
    }

    Ok((request_body_state, bytes))
}
152
#[cfg(test)]
mod tests {
    use super::{chunk_decode, ReceivedBody, ReceivedBodyState};
    use crate::{http_config::DEFAULT_CONFIG, Buffer, HttpConfig};
    use encoding_rs::UTF_8;
    use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt};
    use trillium_testing::block_on;

    /// Runs `chunk_decode` over `input_data` with `remaining` bytes
    /// pending from the current chunk, then asserts the resulting
    /// (state, decoded payload, retained buffer) triple matches
    /// `expected_output`.
    #[track_caller]
    fn assert_decoded(
        (remaining, input_data): (u64, &str),
        expected_output: (Option<u64>, &str, &str),
    ) {
        let mut decode_buf = input_data.to_string().into_bytes();
        let mut retained = Buffer::with_capacity(100);

        let (output_state, bytes) = chunk_decode(
            &mut retained,
            remaining,
            0,
            &mut decode_buf,
            DEFAULT_CONFIG.received_body_max_len,
        )
        .unwrap();

        // Collapse the state into the Option<u64> shape the expectations use.
        let remaining_out = match output_state {
            ReceivedBodyState::Chunked { remaining, .. } => Some(remaining),
            ReceivedBodyState::PartialChunkSize { .. } => Some(0),
            ReceivedBodyState::End => None,
            _ => panic!("unexpected output state {output_state:?}"),
        };

        assert_eq!(
            (
                remaining_out,
                &*String::from_utf8_lossy(&decode_buf[0..bytes]),
                &*String::from_utf8_lossy(&retained)
            ),
            expected_output
        );
    }

    /// Drains `reader` to EOF using a fresh `size`-byte scratch buffer
    /// per poll, returning everything read as a lossy UTF-8 string.
    async fn read_with_buffers_of_size<R>(reader: &mut R, size: usize) -> crate::Result<String>
    where
        R: AsyncRead + Unpin,
    {
        let mut collected = vec![];
        loop {
            let mut scratch = vec![0; size];
            let bytes_read = reader.read(&mut scratch).await?;
            if bytes_read == 0 {
                break Ok(String::from_utf8_lossy(&collected).into());
            }
            collected.extend_from_slice(&scratch[..bytes_read]);
        }
    }

    /// Builds a `ReceivedBody` in the `Start` state over an in-memory
    /// transport containing `input`.
    fn new_with_config(input: String, config: &HttpConfig) -> ReceivedBody<'_, Cursor<String>> {
        let scratch = Buffer::from(Vec::with_capacity(config.response_header_initial_capacity));
        ReceivedBody::new_with_config(
            None,
            scratch,
            Cursor::new(input),
            ReceivedBodyState::Start,
            None,
            UTF_8,
            config,
        )
    }

    /// Fully decodes `input` with the given config, polling `poll_size`
    /// bytes at a time.
    async fn decode_with_config(
        input: String,
        poll_size: usize,
        config: &HttpConfig,
    ) -> crate::Result<String> {
        read_with_buffers_of_size(&mut new_with_config(input, config), poll_size).await
    }

    /// `decode_with_config` with the default configuration.
    async fn decode(input: String, poll_size: usize) -> crate::Result<String> {
        decode_with_config(input, poll_size, &DEFAULT_CONFIG).await
    }

    #[test]
    fn test_full_decode() {
        block_on(async {
            for size in 1..50 {
                let decoded =
                    decode("5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n".into(), size)
                        .await
                        .unwrap();
                assert_eq!(decoded, "12345abcdef", "size: {size}");

                let decoded = decode(
                    "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n".into(),
                    size,
                )
                .await
                .unwrap();
                assert_eq!(decoded, "MozillaDeveloperNetwork", "size: {size}");

                // An empty transport and an overflowing chunk size both fail.
                assert!(decode(String::new(), size).await.is_err());
                assert!(decode("fffffffffffffff0\r\n".into(), size).await.is_err());
            }
        });
    }

    /// Round-trips `input` through the streaming body encoder to
    /// produce its chunked-encoded form.
    async fn build_chunked_body(input: String) -> String {
        let mut encoded = Vec::with_capacity(10);
        let len = crate::copy(
            crate::Body::new_streaming(Cursor::new(input), None),
            &mut encoded,
            16,
        )
        .await
        .unwrap();

        encoded.truncate(len.try_into().unwrap());
        String::from_utf8(encoded).unwrap()
    }

    #[test]
    fn test_read_buffer_short() {
        block_on(async {
            let plain = "test ".repeat(50);
            let encoded = build_chunked_body(plain.clone()).await;

            // Tiny poll buffers must still reassemble the full body.
            for size in 1..10 {
                assert_eq!(
                    &decode(encoded.clone(), size).await.unwrap(),
                    &plain,
                    "size: {size}"
                );
            }
        });
    }

    #[test]
    fn test_max_len() {
        block_on(async {
            let encoded = build_chunked_body("test ".repeat(10)).await;

            for size in 4..10 {
                // A 5-byte cap must reject the 50-byte body…
                let capped = HttpConfig::default().with_received_body_max_len(5);
                assert!(decode_with_config(encoded.clone(), size, &capped)
                    .await
                    .is_err());

                // …while the default cap accepts it.
                assert!(
                    decode_with_config(encoded.clone(), size, &HttpConfig::default())
                        .await
                        .is_ok()
                );
            }
        });
    }

    #[test]
    fn test_chunk_start() {
        // Whole chunk in one read, trailing CRLF consumed.
        assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", ""));
        // Partial chunk payloads: remaining includes the 2 CRLF bytes.
        assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", ""));
        assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", ""));
        assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n"), (Some(0), "XX", ""));
        assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n1"), (Some(0), "XX", "1"));
        assert_decoded((0, "FFF\r\n"), (Some(0xfff + 2), "", ""));
        // Resuming mid-chunk via a nonzero starting `remaining`.
        assert_decoded((10, "hello"), (Some(5), "hello", ""));
        assert_decoded(
            (7, "hello\r\nA\r\n world"),
            (Some(4 + 2), "hello world", ""),
        );
        // Terminal chunk reaches the End state.
        assert_decoded(
            (0, "e\r\ntest test test\r\n0\r\n\r\n"),
            (None, "test test test", ""),
        );
        // Bytes after the terminal chunk are preserved for the next request.
        assert_decoded(
            (0, "1\r\n_\r\n0\r\n\r\nnext request"),
            (None, "_", "next request"),
        );
        assert_decoded((7, "hello\r\n0\r\n\r\n"), (None, "hello", ""));
    }

    #[test]
    fn read_string_and_read_bytes() {
        block_on(async {
            let encoded = build_chunked_body("test ".repeat(100)).await;

            // Both readers decode the full 500-byte body by default.
            let as_string = new_with_config(encoded.clone(), &DEFAULT_CONFIG)
                .read_string()
                .await
                .unwrap();
            assert_eq!(as_string.len(), 500);

            let as_bytes = new_with_config(encoded.clone(), &DEFAULT_CONFIG)
                .read_bytes()
                .await
                .unwrap();
            assert_eq!(as_bytes.len(), 500);

            // A config-level cap below the body length fails both readers.
            let capped = DEFAULT_CONFIG.with_received_body_max_len(400);
            assert!(new_with_config(encoded.clone(), &capped)
                .read_string()
                .await
                .is_err());

            assert!(new_with_config(encoded.clone(), &capped)
                .read_bytes()
                .await
                .is_err());

            // A per-body cap via with_max_len fails both readers too.
            assert!(new_with_config(encoded.clone(), &DEFAULT_CONFIG)
                .with_max_len(400)
                .read_bytes()
                .await
                .is_err());

            assert!(new_with_config(encoded.clone(), &DEFAULT_CONFIG)
                .with_max_len(400)
                .read_string()
                .await
                .is_err());
        });
    }
}