trillium_http/received_body/
chunked.rs

1use super::{
2    io, ready, slice_from, AsyncRead, Buffer, Chunked, Context, End, ErrorKind, InvalidChunkSize,
3    PartialChunkSize, Pin, Ready, ReceivedBody, ReceivedBodyState, StateOutput, Status,
4};
5
6impl<'conn, Transport> ReceivedBody<'conn, Transport>
7where
8    Transport: AsyncRead + Unpin + Send + Sync + 'static,
9{
10    #[inline]
11    pub(super) fn handle_chunked(
12        &mut self,
13        cx: &mut Context<'_>,
14        buf: &mut [u8],
15        remaining: u64,
16        total: u64,
17    ) -> StateOutput {
18        let bytes = ready!(self.read_raw(cx, buf)?);
19
20        Ready(chunk_decode(
21            &mut self.buffer,
22            remaining,
23            total,
24            &mut buf[..bytes],
25            self.max_len,
26        ))
27    }
28
29    #[inline]
30    pub(super) fn handle_partial(
31        &mut self,
32        cx: &mut Context<'_>,
33        buf: &mut [u8],
34        total: u64,
35    ) -> StateOutput {
36        let transport = self
37            .transport
38            .as_deref_mut()
39            .ok_or_else(|| io::Error::from(ErrorKind::NotConnected))?;
40        let bytes = ready!(Pin::new(transport).poll_read(cx, buf))?;
41
42        if bytes == 0 {
43            return Ready(Err(io::Error::from(ErrorKind::ConnectionAborted)));
44        }
45
46        self.buffer.extend_from_slice(&buf[..bytes]);
47
48        match httparse::parse_chunk_size(&self.buffer) {
49            Ok(Status::Complete((framing_bytes, remaining))) => {
50                self.buffer.ignore_front(framing_bytes);
51                Ready(Ok((
52                    if remaining == 0 {
53                        End
54                    } else {
55                        Chunked {
56                            remaining: remaining + 2,
57                            total,
58                        }
59                    },
60                    0,
61                )))
62            }
63
64            Ok(Status::Partial) => Ready(Ok((PartialChunkSize { total }, 0))),
65
66            Err(InvalidChunkSize) => Ready(Err(io::Error::new(
67                ErrorKind::InvalidData,
68                "invalid chunk framing",
69            ))),
70        }
71    }
72}
73
74pub(super) fn chunk_decode(
75    self_buffer: &mut Buffer,
76    remaining: u64,
77    mut total: u64,
78    buf: &mut [u8],
79    max_len: u64,
80) -> io::Result<(ReceivedBodyState, usize)> {
81    if buf.is_empty() {
82        return Err(io::Error::from(ErrorKind::ConnectionAborted));
83    }
84    let mut ranges_to_keep = vec![];
85    let mut chunk_start = 0u64;
86    let mut chunk_end = remaining;
87    let request_body_state = loop {
88        if chunk_end > 2 {
89            let keep_start = usize::try_from(chunk_start).unwrap_or(usize::MAX);
90            let keep_end = buf
91                .len()
92                .min(usize::try_from(chunk_end - 2).unwrap_or(usize::MAX));
93            ranges_to_keep.push(keep_start..keep_end);
94            let new_bytes = (keep_end - keep_start) as u64;
95            total += new_bytes;
96            if total > max_len {
97                return Err(io::Error::new(ErrorKind::Unsupported, "content too long"));
98            }
99        }
100        chunk_start = chunk_end;
101
102        let Some(buf_to_read) = slice_from(chunk_start, buf) else {
103            break Chunked {
104                remaining: (chunk_start - buf.len() as u64),
105                total,
106            };
107        };
108
109        if buf_to_read.is_empty() {
110            break Chunked {
111                remaining: (chunk_start - buf.len() as u64),
112                total,
113            };
114        }
115
116        match httparse::parse_chunk_size(buf_to_read) {
117            Ok(Status::Complete((framing_bytes, chunk_size))) => {
118                chunk_start += framing_bytes as u64;
119                chunk_end = (2 + chunk_start)
120                    .checked_add(chunk_size)
121                    .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "chunk size too long"))?;
122
123                if chunk_size == 0 {
124                    if let Some(buf) = slice_from(chunk_end, buf) {
125                        self_buffer.extend_from_slice(buf);
126                    }
127                    break End;
128                }
129            }
130
131            Ok(Status::Partial) => {
132                self_buffer.extend_from_slice(buf_to_read);
133                break PartialChunkSize { total };
134            }
135
136            Err(InvalidChunkSize) => {
137                return Err(io::Error::new(ErrorKind::InvalidData, "invalid chunk size"));
138            }
139        }
140    };
141
142    let mut bytes = 0;
143
144    for range_to_keep in ranges_to_keep {
145        let new_bytes = bytes + range_to_keep.end - range_to_keep.start;
146        buf.copy_within(range_to_keep, bytes);
147        bytes = new_bytes;
148    }
149
150    Ok((request_body_state, bytes))
151}
152
#[cfg(test)]
mod tests {
    use super::{chunk_decode, ReceivedBody, ReceivedBodyState};
    use crate::{http_config::DEFAULT_CONFIG, Buffer, HttpConfig};
    use encoding_rs::UTF_8;
    use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt};
    use trillium_testing::block_on;

    /// Runs `chunk_decode` once over `input_data` with the given starting `remaining` and a
    /// starting total of zero, then compares three things against `expected_output`:
    /// the resulting `remaining` (`Some(0)` for a partial chunk size, `None` for `End`),
    /// the decoded bytes at the front of the read buffer, and the bytes left in the
    /// carry-over buffer.
    #[track_caller]
    fn assert_decoded(
        (remaining, input_data): (u64, &str),
        expected_output: (Option<u64>, &str, &str),
    ) {
        let mut read_buf = input_data.as_bytes().to_vec();
        let mut carry_over = Buffer::with_capacity(100);

        let (output_state, decoded_len) = chunk_decode(
            &mut carry_over,
            remaining,
            0,
            &mut read_buf,
            DEFAULT_CONFIG.received_body_max_len,
        )
        .unwrap();

        let remaining_after = match output_state {
            ReceivedBodyState::Chunked { remaining, .. } => Some(remaining),
            ReceivedBodyState::PartialChunkSize { .. } => Some(0),
            ReceivedBodyState::End => None,
            _ => panic!("unexpected output state {output_state:?}"),
        };
        let decoded = String::from_utf8_lossy(&read_buf[0..decoded_len]);
        let buffered = String::from_utf8_lossy(&carry_over);

        assert_eq!((remaining_after, &*decoded, &*buffered), expected_output);
    }

    /// Drains `reader` using a fresh zeroed buffer of `size` bytes for every read and
    /// returns everything that was read as a (lossily decoded) string.
    async fn read_with_buffers_of_size<R>(reader: &mut R, size: usize) -> crate::Result<String>
    where
        R: AsyncRead + Unpin,
    {
        let mut collected: Vec<u8> = Vec::new();
        loop {
            let mut chunk = vec![0u8; size];
            let bytes_read = reader.read(&mut chunk).await?;
            if bytes_read == 0 {
                return Ok(String::from_utf8_lossy(&collected).into());
            }
            collected.extend_from_slice(&chunk[..bytes_read]);
        }
    }

    /// Builds a `ReceivedBody` in the `Start` state that reads `input` from an in-memory
    /// cursor under the given `config`.
    fn new_with_config(input: String, config: &HttpConfig) -> ReceivedBody<'_, Cursor<String>> {
        let buffer = Buffer::from(Vec::with_capacity(config.response_header_initial_capacity));
        let transport = Cursor::new(input);

        ReceivedBody::new_with_config(
            None,
            buffer,
            transport,
            ReceivedBodyState::Start,
            None,
            UTF_8,
            config,
        )
    }

    /// Decodes `input` as a body under `config`, reading `poll_size` bytes at a time.
    async fn decode_with_config(
        input: String,
        poll_size: usize,
        config: &HttpConfig,
    ) -> crate::Result<String> {
        let mut body = new_with_config(input, config);
        read_with_buffers_of_size(&mut body, poll_size).await
    }

    /// Same as `decode_with_config`, using `DEFAULT_CONFIG`.
    async fn decode(input: String, poll_size: usize) -> crate::Result<String> {
        decode_with_config(input, poll_size, &DEFAULT_CONFIG).await
    }

    #[test]
    fn test_full_decode() {
        // (chunked input, expected decoded output)
        let valid_cases = [
            (
                "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n",
                "12345abcdef",
            ),
            (
                "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n",
                "MozillaDeveloperNetwork",
            ),
        ];
        // Inputs that must fail: an empty stream and a huge chunk size with no data.
        let invalid_cases = ["", "fffffffffffffff0\r\n"];

        block_on(async {
            for size in 1..50 {
                for (input, expected) in valid_cases {
                    let output = decode(input.into(), size).await.unwrap();
                    assert_eq!(output, expected, "size: {size}");
                }

                for input in invalid_cases {
                    assert!(decode(input.into(), size).await.is_err());
                }
            }
        });
    }

    /// Encodes `input` by copying a streaming `Body` of unknown length into a byte vector
    /// and returns the copied bytes as a string.
    async fn build_chunked_body(input: String) -> String {
        let mut encoded = Vec::with_capacity(10);
        let body = crate::Body::new_streaming(Cursor::new(input), None);
        let len = crate::copy(body, &mut encoded, 16).await.unwrap();

        encoded.truncate(len.try_into().unwrap());
        String::from_utf8(encoded).unwrap()
    }

    #[test]
    fn test_read_buffer_short() {
        block_on(async {
            let expected = "test ".repeat(50);
            let chunked = build_chunked_body(expected.clone()).await;

            for size in 1..10 {
                let decoded = decode(chunked.clone(), size).await.unwrap();
                assert_eq!(&decoded, &expected, "size: {size}");
            }
        });
    }

    #[test]
    fn test_max_len() {
        block_on(async {
            let input = build_chunked_body("test ".repeat(10)).await;

            for size in 4..10 {
                let limited = decode_with_config(
                    input.clone(),
                    size,
                    &HttpConfig::default().with_received_body_max_len(5),
                )
                .await;
                assert!(limited.is_err());

                let unlimited =
                    decode_with_config(input.clone(), size, &HttpConfig::default()).await;
                assert!(unlimited.is_ok());
            }
        });
    }

    #[test]
    fn test_chunk_start() {
        // Reads that begin on a chunk-size line.
        assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", ""));
        assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", ""));
        assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", ""));
        assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n"), (Some(0), "XX", ""));
        assert_decoded((0, "1\r\nX\r\n1\r\nX\r\n1"), (Some(0), "XX", "1"));
        assert_decoded((0, "FFF\r\n"), (Some(0xfff + 2), "", ""));

        // Reads that begin in the middle of a chunk.
        assert_decoded((10, "hello"), (Some(5), "hello", ""));
        assert_decoded(
            (7, "hello\r\nA\r\n world"),
            (Some(4 + 2), "hello world", ""),
        );

        // Reads that reach the terminating zero-sized chunk.
        assert_decoded(
            (0, "e\r\ntest test test\r\n0\r\n\r\n"),
            (None, "test test test", ""),
        );
        assert_decoded(
            (0, "1\r\n_\r\n0\r\n\r\nnext request"),
            (None, "_", "next request"),
        );
        assert_decoded((7, "hello\r\n0\r\n\r\n"), (None, "hello", ""));
    }

    #[test]
    fn read_string_and_read_bytes() {
        block_on(async {
            let content = build_chunked_body("test ".repeat(100)).await;

            // Without a restrictive limit the full 500-byte payload is returned.
            let as_string = new_with_config(content.clone(), &DEFAULT_CONFIG)
                .read_string()
                .await
                .unwrap();
            assert_eq!(as_string.len(), 500);

            let as_bytes = new_with_config(content.clone(), &DEFAULT_CONFIG)
                .read_bytes()
                .await
                .unwrap();
            assert_eq!(as_bytes.len(), 500);

            // A limit set through the config is enforced.
            let string_with_config_limit = new_with_config(
                content.clone(),
                &DEFAULT_CONFIG.with_received_body_max_len(400),
            )
            .read_string()
            .await;
            assert!(string_with_config_limit.is_err());

            let bytes_with_config_limit = new_with_config(
                content.clone(),
                &DEFAULT_CONFIG.with_received_body_max_len(400),
            )
            .read_bytes()
            .await;
            assert!(bytes_with_config_limit.is_err());

            // A limit set on the body itself is enforced as well.
            let bytes_with_body_limit = new_with_config(content.clone(), &DEFAULT_CONFIG)
                .with_max_len(400)
                .read_bytes()
                .await;
            assert!(bytes_with_body_limit.is_err());

            let string_with_body_limit = new_with_config(content.clone(), &DEFAULT_CONFIG)
                .with_max_len(400)
                .read_string()
                .await;
            assert!(string_with_body_limit.is_err());
        });
    }
}