xitca_http/h1/proto/
codec.rs

1use core::{fmt, mem};
2
3use std::io;
4
5use tracing::{trace, warn};
6
7use crate::bytes::{Buf, Bytes, BytesMut};
8
9use super::{buf_write::H1BufWrite, error::ProtoError};
10
/// Coder for different Transfer-Decoding/Transfer-Encoding.
///
/// The same type drives both directions: [TransferCoding::decode] consumes an incoming body and
/// [TransferCoding::encode] / [TransferCoding::encode_eof] produce an outgoing one.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TransferCoding {
    /// Default coder. Indicates the Request/Response has successfully coded its associated body.
    Eof,
    /// Corrupted coder that can not be used anymore.
    Corrupted,
    /// Coder used when a Content-Length header is passed with a positive integer.
    /// The payload is the number of body bytes that still have to be coded.
    Length(u64),
    /// Decoder used when Transfer-Encoding is `chunked`.
    /// The payload is the parser state plus the size counter of the chunk currently being read.
    DecodeChunked(ChunkedState, u64),
    /// Encoder for when Transfer-Encoding includes `chunked`.
    EncodeChunked,
    /// Upgrade type coder that passes the body through as is without transforming it.
    Upgrade,
}
27
28impl TransferCoding {
29    #[inline]
30    pub const fn eof() -> Self {
31        Self::Eof
32    }
33
34    #[inline]
35    pub const fn length(len: u64) -> Self {
36        Self::Length(len)
37    }
38
39    #[inline]
40    pub const fn decode_chunked() -> Self {
41        Self::DecodeChunked(ChunkedState::Size, 0)
42    }
43
44    #[inline]
45    pub const fn encode_chunked() -> Self {
46        Self::EncodeChunked
47    }
48
49    #[inline]
50    pub const fn upgrade() -> Self {
51        Self::Upgrade
52    }
53
54    /// Check if Self is in EOF state. An EOF state means TransferCoding is ended gracefully
55    /// and can not decode any value. See [TransferCoding::decode] for detail.
56    #[inline]
57    pub fn is_eof(&self) -> bool {
58        match self {
59            Self::Eof => true,
60            Self::EncodeChunked => unreachable!("TransferCoding can't decide eof state when encoding chunked data"),
61            _ => false,
62        }
63    }
64
65    #[inline]
66    pub fn is_upgrade(&self) -> bool {
67        matches!(self, Self::Upgrade)
68    }
69}
70
/// Parser state of the chunked body decoder. See [ChunkedState::step] for how a state is advanced.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChunkedState {
    /// Reading the hexadecimal digits of a chunk size.
    Size,
    /// Reading white space that follows the chunk size. No more digits are accepted.
    SizeLws,
    /// Skipping a chunk extension until the carriage return ending the size line.
    Extension,
    /// Expecting the line feed ending the size line.
    SizeLf,
    /// Reading chunk payload bytes.
    Body,
    /// Expecting the carriage return that follows the chunk payload.
    BodyCr,
    /// Expecting the line feed that follows the chunk payload.
    BodyLf,
    /// Skipping the bytes of a trailer line until its carriage return.
    Trailer,
    /// Expecting the line feed ending a trailer line.
    TrailerLf,
    /// After the zero sized chunk or a trailer line: a carriage return starts the final empty
    /// line, any other byte starts a trailer line.
    EndCr,
    /// Expecting the line feed of the final empty line.
    EndLf,
    /// The chunked body is complete.
    End,
}
86
// Pop the first byte from the reader named by `$rdr` and evaluate to it.
// When the reader holds no bytes the macro returns `Ok(None)` from the *enclosing function*,
// which is how the parser functions signal that more input is needed.
macro_rules! byte (
    ($rdr:ident) => ({
        match $rdr.first().copied() {
            Some(b) => {
                $rdr.advance(1);
                b
            }
            None => return Ok(None),
        }
    })
);
98
99impl ChunkedState {
100    pub fn step(&mut self, body: &mut BytesMut, size: &mut u64, buf: &mut Option<Bytes>) -> io::Result<Option<Self>> {
101        match *self {
102            Self::Size => Self::read_size(body, size),
103            Self::SizeLws => Self::read_size_lws(body),
104            Self::Extension => Self::read_extension(body),
105            Self::SizeLf => Self::read_size_lf(body, size),
106            Self::Body => Self::read_body(body, size, buf),
107            Self::BodyCr => Self::read_body_cr(body),
108            Self::BodyLf => Self::read_body_lf(body),
109            Self::Trailer => Self::read_trailer(body),
110            Self::TrailerLf => Self::read_trailer_lf(body),
111            Self::EndCr => Self::read_end_cr(body),
112            Self::EndLf => Self::read_end_lf(body),
113            Self::End => Ok(Some(Self::End)),
114        }
115    }
116
117    fn read_size(rdr: &mut BytesMut, size: &mut u64) -> io::Result<Option<Self>> {
118        macro_rules! or_overflow {
119            ($e:expr) => (
120                match $e {
121                    Some(val) => val,
122                    None => return Err(io::Error::new(
123                        io::ErrorKind::InvalidData,
124                        "invalid chunk size: overflow",
125                    )),
126                }
127            )
128        }
129
130        let radix = 16;
131        match byte!(rdr) {
132            b @ b'0'..=b'9' => {
133                *size = or_overflow!(size.checked_mul(radix));
134                *size = or_overflow!(size.checked_add((b - b'0') as u64));
135            }
136            b @ b'a'..=b'f' => {
137                *size = or_overflow!(size.checked_mul(radix));
138                *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
139            }
140            b @ b'A'..=b'F' => {
141                *size = or_overflow!(size.checked_mul(radix));
142                *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
143            }
144            b'\t' | b' ' => return Ok(Some(ChunkedState::SizeLws)),
145            b';' => return Ok(Some(ChunkedState::Extension)),
146            b'\r' => return Ok(Some(ChunkedState::SizeLf)),
147            _ => {
148                return Err(io::Error::new(
149                    io::ErrorKind::InvalidInput,
150                    "Invalid chunk size line: Invalid Size",
151                ));
152            }
153        }
154
155        Ok(Some(ChunkedState::Size))
156    }
157
158    fn read_size_lws(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
159        match byte!(rdr) {
160            // LWS can follow the chunk size, but no more digits can come
161            b'\t' | b' ' => Ok(Some(Self::SizeLws)),
162            b';' => Ok(Some(Self::Extension)),
163            b'\r' => Ok(Some(Self::SizeLf)),
164            _ => Err(io::Error::new(
165                io::ErrorKind::InvalidInput,
166                "Invalid chunk size linear white space",
167            )),
168        }
169    }
170
171    fn read_extension(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
172        match byte!(rdr) {
173            b'\r' => Ok(Some(Self::SizeLf)),
174            b'\n' => Err(io::Error::new(
175                io::ErrorKind::InvalidData,
176                "invalid chunk extension contains newline",
177            )),
178            _ => Ok(Some(Self::Extension)), // no supported extensions
179        }
180    }
181
182    fn read_size_lf(rdr: &mut BytesMut, size: &u64) -> io::Result<Option<Self>> {
183        match byte!(rdr) {
184            b'\n' if *size > 0 => Ok(Some(Self::Body)),
185            b'\n' if *size == 0 => Ok(Some(Self::EndCr)),
186            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")),
187        }
188    }
189
190    fn read_body(rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option<Bytes>) -> io::Result<Option<Self>> {
191        if rdr.is_empty() {
192            Ok(None)
193        } else {
194            *buf = Some(bounded_split(rem, rdr));
195            if *rem > 0 {
196                Ok(Some(Self::Body))
197            } else {
198                Ok(Some(Self::BodyCr))
199            }
200        }
201    }
202
203    fn read_body_cr(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
204        match byte!(rdr) {
205            b'\r' => Ok(Some(Self::BodyLf)),
206            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")),
207        }
208    }
209
210    fn read_body_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
211        match byte!(rdr) {
212            b'\n' => Ok(Some(Self::Size)),
213            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")),
214        }
215    }
216
217    fn read_trailer(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
218        trace!(target: "h1_decode", "read_trailer");
219        match byte!(rdr) {
220            b'\r' => Ok(Some(Self::TrailerLf)),
221            _ => Ok(Some(Self::Trailer)),
222        }
223    }
224
225    fn read_trailer_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
226        match byte!(rdr) {
227            b'\n' => Ok(Some(Self::EndCr)),
228            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid trailer end LF")),
229        }
230    }
231
232    fn read_end_cr(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
233        match byte!(rdr) {
234            b'\r' => Ok(Some(Self::EndLf)),
235            _ => Ok(Some(Self::Trailer)),
236        }
237    }
238
239    fn read_end_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
240        match byte!(rdr) {
241            b'\n' => Ok(Some(Self::End)),
242            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end LF")),
243        }
244    }
245}
246
247impl TransferCoding {
248    pub fn try_set(&mut self, other: Self) -> Result<(), ProtoError> {
249        match (&self, &other) {
250            // multiple set to plain chunked is allowed. This can happen from Connect method
251            // and/or Connection header.
252            // skip set when the request body is zero length.
253            (TransferCoding::Upgrade, TransferCoding::Upgrade) | (_, TransferCoding::Length(0)) => Ok(()),
254            // multiple set to decoded chunked/content-length are forbidden.
255            // mutation between decoded chunked/content-length/plain chunked is forbidden.
256            (TransferCoding::Upgrade, _) | (TransferCoding::DecodeChunked(..), _) | (TransferCoding::Length(..), _) => {
257                Err(ProtoError::HeaderName)
258            }
259            _ => {
260                *self = other;
261                Ok(())
262            }
263        }
264    }
265
266    #[inline]
267    pub fn set_eof(&mut self) {
268        *self = Self::Eof;
269    }
270
271    #[inline]
272    pub fn set_corrupted(&mut self) {
273        *self = Self::Corrupted;
274    }
275
276    /// Encode message. Return `EOF` state of encoder
277    pub fn encode<W>(&mut self, mut bytes: Bytes, buf: &mut W)
278    where
279        W: H1BufWrite,
280    {
281        // Skip encode empty bytes.
282        // This is to avoid unnecessary extending on h1::proto::buf::ListBuf when user
283        // provided empty bytes by accident.
284        if bytes.is_empty() {
285            return;
286        }
287
288        match *self {
289            Self::Upgrade => buf.write_buf_bytes(bytes),
290            Self::EncodeChunked => buf.write_buf_bytes_chunked(bytes),
291            Self::Length(ref mut rem) => {
292                let len = bytes.len() as u64;
293                if *rem >= len {
294                    buf.write_buf_bytes(bytes);
295                    *rem -= len;
296                } else {
297                    let rem = mem::replace(rem, 0u64);
298                    buf.write_buf_bytes(bytes.split_to(rem as usize));
299                }
300            }
301            Self::Eof => warn!(target: "h1_encode", "TransferCoding::Eof should not encode response body"),
302            _ => unreachable!(),
303        }
304    }
305
306    /// Encode eof. Return `EOF` state of encoder
307    pub fn encode_eof<W>(&mut self, buf: &mut W)
308    where
309        W: H1BufWrite,
310    {
311        match *self {
312            Self::Eof | Self::Upgrade | Self::Length(0) => {}
313            Self::EncodeChunked => buf.write_buf_static(b"0\r\n\r\n"),
314            Self::Length(n) => unreachable!("UnexpectedEof for Length Body with {} remaining", n),
315            _ => unreachable!(),
316        }
317    }
318
319    /// decode body. See [ChunkResult] for detailed outcome.
320    pub fn decode(&mut self, src: &mut BytesMut) -> ChunkResult {
321        match *self {
322            // when decoder reaching eof state it would return ChunkResult::Eof and followed by
323            // ChunkResult::AlreadyEof if decode is called again.
324            // This multi stage behaviour is depended on by the caller to know the exact timing of
325            // when eof happens. (Expensive one time operations can be happening at Eof)
326            Self::Length(0) | Self::DecodeChunked(ChunkedState::End, _) => {
327                *self = Self::Eof;
328                ChunkResult::OnEof
329            }
330            Self::Eof => ChunkResult::AlreadyEof,
331            Self::Corrupted => ChunkResult::Corrupted,
332            ref _this if src.is_empty() => ChunkResult::InsufficientData,
333            Self::Length(ref mut rem) => ChunkResult::Ok(bounded_split(rem, src)),
334            Self::Upgrade => ChunkResult::Ok(src.split().freeze()),
335            Self::DecodeChunked(ref mut state, ref mut size) => {
336                loop {
337                    let mut buf = None;
338                    // advances the chunked state
339                    *state = match state.step(src, size, &mut buf) {
340                        Ok(Some(state)) => state,
341                        Ok(None) => return ChunkResult::InsufficientData,
342                        Err(e) => return ChunkResult::Err(e),
343                    };
344
345                    if matches!(state, ChunkedState::End) {
346                        return self.decode(src);
347                    }
348
349                    if let Some(buf) = buf {
350                        return ChunkResult::Ok(buf);
351                    }
352                }
353            }
354            _ => unreachable!(),
355        }
356    }
357}
358
/// Outcome of a single [TransferCoding::decode] call.
#[derive(Debug)]
pub enum ChunkResult {
    /// non empty chunk data produced by coder.
    Ok(Bytes),
    /// io error type produced by coder that can be bubbled up to upstream caller.
    Err(io::Error),
    /// insufficient data. More input bytes required.
    InsufficientData,
    /// coder reached EOF state and no more chunk can be produced.
    /// returned once, by the call that observes the end of the body.
    OnEof,
    /// coder already reached EOF state and no more chunk can be produced.
    /// used to hint the caller to stop filling input buffer with more data and/or calling method again.
    AlreadyEof,
    /// see [TransferCoding::Corrupted].
    Corrupted,
}
375
376impl fmt::Display for ChunkResult {
377    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378        match *self {
379            Self::Ok(_) => f.write_str("chunked data."),
380            Self::Err(ref e) => fmt::Display::fmt(e, f),
381            Self::InsufficientData => f.write_str("no sufficient data. More input bytes required."),
382            Self::OnEof => f.write_str("coder reached EOF state. no more chunk can be produced."),
383            Self::AlreadyEof => f.write_str("coder already reached EOF state. no more chunk can be produced."),
384            Self::Corrupted => f.write_str("coder corrupted. can not be used anymore."),
385        }
386    }
387}
388
389impl From<io::Error> for ChunkResult {
390    fn from(e: io::Error) -> Self {
391        Self::Err(e)
392    }
393}
394
395fn bounded_split(rem: &mut u64, buf: &mut BytesMut) -> Bytes {
396    let len = buf.len() as u64;
397    if *rem >= len {
398        *rem -= len;
399        buf.split().freeze()
400    } else {
401        let rem = mem::replace(rem, 0);
402        buf.split_to(rem as usize).freeze()
403    }
404}
405
// Unit tests for the chunk size parser, the chunked decoder and the body encoders.
#[cfg(test)]
mod test {
    use crate::util::buffered::WriteBuf;

    use super::*;

    // Drives the state machine over chunk size lines and checks the parsed size or error kind.
    #[test]
    fn test_read_chunk_size() {
        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};

        // Parse a size line that is expected to be valid and return the parsed size.
        // Stops once the size line is complete (Body for a non zero size, EndCr for zero).
        fn read(s: &str) -> u64 {
            let mut state = ChunkedState::Size;
            let rdr = &mut BytesMut::from(s);
            let mut size = 0;
            loop {
                let result = state.step(rdr, &mut size, &mut None);
                state = result.unwrap_or_else(|_| panic!("read_size failed for {s:?}")).unwrap();
                if state == ChunkedState::Body || state == ChunkedState::EndCr {
                    break;
                }
            }
            size
        }

        // Parse a size line that is expected to fail with `expected_err`.
        // Running out of input (`Ok(None)`) is treated as `UnexpectedEof`.
        fn read_err(s: &str, expected_err: io::ErrorKind) {
            let mut state = ChunkedState::Size;
            let rdr = &mut BytesMut::from(s);
            let mut size = 0;
            loop {
                let result = state.step(rdr, &mut size, &mut None);
                state = match result {
                    Ok(Some(s)) => s,
                    Ok(None) => return assert_eq!(expected_err, UnexpectedEof),
                    Err(e) => {
                        assert_eq!(
                            expected_err,
                            e.kind(),
                            "Reading {:?}, expected {:?}, but got {:?}",
                            s,
                            expected_err,
                            e.kind()
                        );
                        return;
                    }
                };
                if state == ChunkedState::Body || state == ChunkedState::End {
                    panic!("Was Ok. Expected Err for {s:?}");
                }
            }
        }

        assert_eq!(1, read("1\r\n"));
        assert_eq!(1, read("01\r\n"));
        assert_eq!(0, read("0\r\n"));
        assert_eq!(0, read("00\r\n"));
        assert_eq!(10, read("A\r\n"));
        assert_eq!(10, read("a\r\n"));
        assert_eq!(255, read("Ff\r\n"));
        assert_eq!(255, read("Ff   \r\n"));
        // Missing LF or CRLF
        read_err("F\rF", InvalidInput);
        read_err("F", UnexpectedEof);
        // Invalid hex digit
        read_err("X\r\n", InvalidInput);
        read_err("1X\r\n", InvalidInput);
        read_err("-\r\n", InvalidInput);
        read_err("-1\r\n", InvalidInput);
        // Acceptable (if not fully valid) extensions do not influence the size
        assert_eq!(1, read("1;extension\r\n"));
        assert_eq!(10, read("a;ext name=value\r\n"));
        assert_eq!(1, read("1;extension;extension2\r\n"));
        assert_eq!(1, read("1;;;  ;\r\n"));
        assert_eq!(2, read("2; extension...\r\n"));
        assert_eq!(3, read("3   ; extension=123\r\n"));
        assert_eq!(3, read("3   ;\r\n"));
        assert_eq!(3, read("3   ;   \r\n"));
        // Invalid extensions cause an error
        read_err("1 invalid extension\r\n", InvalidInput);
        read_err("1 A\r\n", InvalidInput);
        read_err("1;no CRLF", UnexpectedEof);
        read_err("1;reject\nnewlines\r\n", InvalidData);
        // Overflow
        read_err("f0000000000000003\r\n", InvalidData);
    }

    // A whole 16 byte chunk available in one buffer is returned by the first decode call.
    #[test]
    fn test_read_chunked_single_read() {
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n");

        match TransferCoding::decode_chunked().decode(mock_buf) {
            ChunkResult::Ok(buf) => {
                assert_eq!(16, buf.len());
                let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
                assert_eq!("1234567890abcdef", &result);
            }
            state => panic!("{}", state),
        }
    }

    // A trailer line whose CR is not followed by LF is rejected with InvalidInput.
    #[test]
    fn test_read_chunked_trailer_with_missing_lf() {
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n");

        let mut decoder = TransferCoding::decode_chunked();

        match decoder.decode(mock_buf) {
            ChunkResult::Ok(_) => {}
            state => panic!("{}", state),
        }

        match decoder.decode(mock_buf) {
            ChunkResult::Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidInput),
            state => panic!("{}", state),
        }
    }

    // The decoder reports OnEof exactly once and AlreadyEof on the call after that.
    #[test]
    fn test_read_chunked_after_eof() {
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n\r\n");
        let mut decoder = TransferCoding::decode_chunked();

        // normal read
        match decoder.decode(mock_buf) {
            ChunkResult::Ok(buf) => {
                assert_eq!(16, buf.len());
                let result = String::from_utf8(buf.as_ref().to_vec()).unwrap();
                assert_eq!("1234567890abcdef", &result);
            }
            state => panic!("{}", state),
        }

        // eof read
        match decoder.decode(mock_buf) {
            ChunkResult::OnEof => {}
            state => panic!("{}", state),
        }

        // already meet eof
        match decoder.decode(mock_buf) {
            ChunkResult::AlreadyEof => {}
            state => panic!("{}", state),
        }
    }

    // Every encode call becomes one chunk; encode_eof appends the terminating zero sized chunk.
    #[test]
    fn encode_chunked() {
        let mut encoder = TransferCoding::encode_chunked();
        let dst = &mut WriteBuf::<1024>::default();

        let msg1 = Bytes::from("foo bar");
        encoder.encode(msg1, dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\n");

        let msg2 = Bytes::from("baz quux herp");
        encoder.encode(msg2, dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        encoder.encode_eof(dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n");
    }

    // The length encoder never writes more than the declared length; excess bytes are dropped.
    #[test]
    fn encode_length() {
        let max_len = 8;
        let mut encoder = TransferCoding::length(max_len as u64);

        let dst = &mut WriteBuf::<1024>::default();

        let msg1 = Bytes::from("foo bar");
        encoder.encode(msg1, dst);

        assert_eq!(dst.buf(), b"foo bar");

        // Only the first iteration has budget left (one byte); later ones must not add anything.
        for _ in 0..8 {
            let msg2 = Bytes::from("baz");
            encoder.encode(msg2, dst);

            assert_eq!(dst.buf().len(), max_len);
            assert_eq!(dst.buf(), b"foo barb");
        }

        encoder.encode_eof(dst);
        assert_eq!(dst.buf().len(), max_len);
        assert_eq!(dst.buf(), b"foo barb");
    }
}
594}