// xitca_http/h1/proto/decode.rs

1use core::mem::MaybeUninit;
2
3use httparse::Status;
4
5use crate::{
6    bytes::{Buf, Bytes, BytesMut},
7    http::{
8        Extension, Method, Request, RequestExt, Uri, Version,
9        header::{CONNECTION, CONTENT_LENGTH, EXPECT, HeaderMap, HeaderName, HeaderValue, TRANSFER_ENCODING, UPGRADE},
10    },
11};
12
13use super::{
14    codec::TransferCoding,
15    context::Context,
16    error::ProtoError,
17    header::{self, HeaderIndex},
18};
19
20type Decoded = (Request<RequestExt<()>>, TransferCoding);
21
22impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
23    // decode head and generate request and body decoder.
24    pub fn decode_head<const READ_BUF_LIMIT: usize>(
25        &mut self,
26        buf: &mut BytesMut,
27    ) -> Result<Option<Decoded>, ProtoError> {
28        let mut req = httparse::Request::new(&mut []);
29        let mut headers = [const { MaybeUninit::uninit() }; MAX_HEADERS];
30
31        match req.parse_with_uninit_headers(buf, &mut headers)? {
32            Status::Complete(len) => {
33                // Important: reset context state for new request.
34                self.reset();
35
36                let method = Method::from_bytes(req.method.unwrap().as_bytes())?;
37
38                // default body decoder from method.
39                let mut decoder = match method {
40                    // set method to context so it can pass method to response.
41                    Method::CONNECT => {
42                        self.set_connect_method();
43                        TransferCoding::upgrade()
44                    }
45                    Method::HEAD => {
46                        self.set_head_method();
47                        TransferCoding::eof()
48                    }
49                    _ => TransferCoding::eof(),
50                };
51
52                // Set connection type when doing version match.
53                let version = if req.version.unwrap() == 1 {
54                    // Default ctype is KeepAlive so set_ctype is skipped here.
55                    Version::HTTP_11
56                } else {
57                    self.set_close();
58                    Version::HTTP_10
59                };
60
61                // record indices of headers from bytes buffer.
62                let mut header_idx = [const { MaybeUninit::uninit() }; MAX_HEADERS];
63                let header_idx_slice = HeaderIndex::record(&mut header_idx, buf, req.headers);
64                let headers_len = req.headers.len();
65
66                // record indices of request path from buffer.
67                let path = req.path.unwrap();
68                let path_head = path.as_ptr() as usize - buf.as_ptr() as usize;
69                let path_len = path.len();
70
71                // split the headers from buffer.
72                let slice = buf.split_to(len).freeze();
73
74                let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
75
76                // pop a cached headermap or construct a new one.
77                let mut headers = self.take_headers();
78                headers.reserve(headers_len);
79
80                // write headers to headermap and update request states.
81                for idx in header_idx_slice {
82                    self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?;
83                }
84
85                let ext = Extension::new(*self.socket_addr());
86                let mut req = Request::new(RequestExt::from_parts((), ext));
87
88                let extensions = self.take_extensions();
89
90                *req.method_mut() = method;
91                *req.version_mut() = version;
92                *req.uri_mut() = uri;
93                *req.headers_mut() = headers;
94                *req.extensions_mut() = extensions;
95
96                Ok(Some((req, decoder)))
97            }
98
99            Status::Partial => {
100                if buf.remaining() >= READ_BUF_LIMIT {
101                    Err(ProtoError::HeaderTooLarge)
102                } else {
103                    Ok(None)
104                }
105            }
106        }
107    }
108
109    pub fn try_write_header(
110        &mut self,
111        headers: &mut HeaderMap,
112        decoder: &mut TransferCoding,
113        idx: &HeaderIndex,
114        slice: &Bytes,
115        version: Version,
116    ) -> Result<(), ProtoError> {
117        let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap();
118        let value = HeaderValue::from_maybe_shared(slice.slice(idx.value.0..idx.value.1)).unwrap();
119
120        match name {
121            TRANSFER_ENCODING => {
122                if version != Version::HTTP_11 {
123                    return Err(ProtoError::HeaderName);
124                }
125                for val in value.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
126                    let val = val.trim();
127                    if val.eq_ignore_ascii_case("chunked") {
128                        decoder.try_set(TransferCoding::decode_chunked())?;
129                    }
130                }
131            }
132            CONTENT_LENGTH => {
133                let len = header::parse_content_length(&value)?;
134                decoder.try_set(TransferCoding::length(len))?;
135            }
136            CONNECTION => self.try_set_close_from_header(&value)?,
137            EXPECT => {
138                if !value.as_bytes().eq_ignore_ascii_case(b"100-continue") {
139                    return Err(ProtoError::HeaderValue);
140                }
141                self.set_expect_header()
142            }
143            UPGRADE => {
144                if version != Version::HTTP_11 {
145                    return Err(ProtoError::HeaderName);
146                }
147                decoder.try_set(TransferCoding::upgrade())?;
148            }
149            _ => {}
150        }
151
152        headers.append(name, value);
153
154        Ok(())
155    }
156
157    pub(super) fn try_set_close_from_header(&mut self, val: &HeaderValue) -> Result<(), ProtoError> {
158        for val in val.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
159            let val = val.trim();
160            if val.eq_ignore_ascii_case("keep-alive") {
161                self.remove_close()
162            } else if val.eq_ignore_ascii_case("close") {
163                self.set_close()
164            }
165        }
166        Ok(())
167    }
168}
169
170#[cfg(test)]
171mod test {
172    use super::*;
173
174    #[test]
175    fn connection_multiple_value() {
176        let mut ctx = Context::<_, 4>::new(&());
177
178        let head = b"\
179                GET / HTTP/1.1\r\n\
180                Connection: keep-alive, upgrade\r\n\
181                \r\n\
182                ";
183        let mut buf = BytesMut::from(&head[..]);
184
185        let _ = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
186        assert!(!ctx.is_connection_closed());
187
188        // this is a wrong connection header but instead of rejecting it the last close value is
189        // used for the final value. there is no particular reason behind this behaviour and this
190        // session of the test serves only as a consistency check to prevent regression.
191        let head = b"\
192                GET / HTTP/1.1\r\n\
193                Connection: keep-alive, close, upgrade\r\n\
194                \r\n\
195                ";
196        let mut buf = BytesMut::from(&head[..]);
197
198        let _ = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
199        assert!(ctx.is_connection_closed());
200
201        let head = b"\
202                GET / HTTP/1.1\r\n\
203                Connection: close, keep-alive, upgrade\r\n\
204                \r\n\
205                ";
206        let mut buf = BytesMut::from(&head[..]);
207
208        let _ = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
209        assert!(!ctx.is_connection_closed());
210    }
211
212    #[test]
213    fn transfer_encoding() {
214        let mut ctx = Context::<_, 4>::new(&());
215
216        let head = b"\
217                GET / HTTP/1.1\r\n\
218                Transfer-Encoding: gzip\r\n\
219                Transfer-Encoding: chunked\r\n\
220                \r\n\
221                ";
222        let mut buf = BytesMut::from(&head[..]);
223
224        let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
225        let mut iter = req.headers().get_all(TRANSFER_ENCODING).into_iter();
226        assert_eq!(iter.next().unwrap().to_str().unwrap(), "gzip");
227        assert_eq!(iter.next().unwrap().to_str().unwrap(), "chunked");
228        assert!(
229            matches!(decoder, TransferCoding::DecodeChunked(..)),
230            "transfer coding is not decoded to chunked"
231        );
232
233        ctx.reset();
234
235        let head = b"\
236        GET / HTTP/1.1\r\n\
237        Transfer-Encoding: chunked\r\n\
238        Transfer-Encoding: gzip\r\n\
239        \r\n\
240        ";
241        let mut buf = BytesMut::from(&head[..]);
242
243        let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
244        let mut iter = req.headers().get_all(TRANSFER_ENCODING).into_iter();
245        assert_eq!(iter.next().unwrap().to_str().unwrap(), "chunked");
246        assert_eq!(iter.next().unwrap().to_str().unwrap(), "gzip");
247        assert!(
248            matches!(decoder, TransferCoding::DecodeChunked(..)),
249            "transfer coding is not decoded to chunked"
250        );
251
252        ctx.reset();
253
254        let head = b"\
255                GET / HTTP/1.1\r\n\
256                Transfer-Encoding: gzip, chunked\r\n\
257                \r\n\
258                ";
259        let mut buf = BytesMut::from(&head[..]);
260
261        let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
262        assert_eq!(
263            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
264            "gzip, chunked"
265        );
266        assert!(
267            matches!(decoder, TransferCoding::DecodeChunked(..)),
268            "transfer coding is not decoded to chunked"
269        );
270
271        ctx.reset();
272
273        let head = b"\
274                GET / HTTP/1.1\r\n\
275                Transfer-Encoding: chunked\r\n\
276                \r\n\
277                ";
278        let mut buf = BytesMut::from(&head[..]);
279
280        let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
281        assert_eq!(
282            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
283            "chunked"
284        );
285        assert!(
286            matches!(decoder, TransferCoding::DecodeChunked(..)),
287            "transfer coding is not decoded to chunked"
288        );
289
290        let head = b"\
291                GET / HTTP/1.1\r\n\
292                Transfer-Encoding: identity\r\n\
293                \r\n\
294                ";
295        let mut buf = BytesMut::from(&head[..]);
296        let (_, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
297        assert!(
298            matches!(decoder, TransferCoding::Eof),
299            "transfer coding is not decoded to eof"
300        );
301
302        let head = b"\
303                GET / HTTP/1.1\r\n\
304                Transfer-Encoding: chunked, gzip\r\n\
305                \r\n\
306                ";
307        let mut buf = BytesMut::from(&head[..]);
308        let (_, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
309        assert!(
310            matches!(decoder, TransferCoding::DecodeChunked(..)),
311            "transfer coding is not decoded to chunked"
312        );
313    }
314}