Skip to main content

xitca_http/h1/proto/
decode.rs

1use core::mem::MaybeUninit;
2
3use httparse::Status;
4
5use crate::{
6    bytes::{Buf, Bytes, BytesMut},
7    http::{
8        Extension, Method, Request, RequestExt, Uri, Version,
9        header::{CONNECTION, CONTENT_LENGTH, EXPECT, HeaderMap, HeaderName, HeaderValue, TRANSFER_ENCODING, UPGRADE},
10    },
11};
12
13use super::{
14    context::Context,
15    error::ProtoError,
16    header::{self, HeaderIndex},
17    trasnder_coding::TransferCoding,
18};
19
20type Decoded = (Request<RequestExt<()>>, TransferCoding);
21
22impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
23    // decode head and generate request and body decoder.
24    pub fn decode_head<const READ_BUF_LIMIT: usize>(
25        &mut self,
26        buf: &mut BytesMut,
27    ) -> Result<Option<Decoded>, ProtoError> {
28        let mut req = httparse::Request::new(&mut []);
29        let mut headers = [const { MaybeUninit::uninit() }; MAX_HEADERS];
30
31        match req.parse_with_uninit_headers(buf, &mut headers)? {
32            Status::Complete(len) => {
33                // Important: reset context state for new request.
34                self.reset();
35
36                let method = Method::from_bytes(req.method.unwrap().as_bytes())?;
37
38                // default body decoder from method.
39                let mut decoder = match method {
40                    // set method to context so it can pass method to response.
41                    Method::CONNECT => {
42                        self.set_connect_method();
43                        TransferCoding::upgrade()
44                    }
45                    Method::HEAD => {
46                        self.set_head_method();
47                        TransferCoding::eof()
48                    }
49                    _ => TransferCoding::eof(),
50                };
51
52                // Set connection type when doing version match.
53                let version = if req.version.unwrap() == 1 {
54                    // Default ctype is KeepAlive so set_ctype is skipped here.
55                    Version::HTTP_11
56                } else {
57                    self.set_http10();
58                    self.set_close();
59                    Version::HTTP_10
60                };
61
62                // record indices of headers from bytes buffer.
63                let mut header_idx = [const { MaybeUninit::uninit() }; MAX_HEADERS];
64                let header_idx_slice = HeaderIndex::record(&mut header_idx, buf, req.headers);
65                let headers_len = req.headers.len();
66
67                // record indices of request path from buffer.
68                let path = req.path.unwrap();
69                let path_head = path.as_ptr() as usize - buf.as_ptr() as usize;
70                let path_len = path.len();
71
72                // split the headers from buffer.
73                let slice = buf.split_to(len).freeze();
74
75                let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
76
77                // pop a cached headermap or construct a new one.
78                let mut headers = self.take_headers();
79                headers.reserve(headers_len);
80
81                // write headers to headermap and update request states.
82                for idx in header_idx_slice {
83                    self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?;
84                }
85
86                let ext = Extension::new(*self.socket_addr());
87                let mut req = Request::new(RequestExt::from_parts((), ext));
88
89                let extensions = self.take_extensions();
90
91                *req.method_mut() = method;
92                *req.version_mut() = version;
93                *req.uri_mut() = uri;
94                *req.headers_mut() = headers;
95                *req.extensions_mut() = extensions;
96
97                Ok(Some((req, decoder)))
98            }
99
100            Status::Partial => {
101                if buf.remaining() >= READ_BUF_LIMIT {
102                    Err(ProtoError::HeaderTooLarge)
103                } else {
104                    Ok(None)
105                }
106            }
107        }
108    }
109
110    pub fn try_write_header(
111        &mut self,
112        headers: &mut HeaderMap,
113        decoder: &mut TransferCoding,
114        idx: &HeaderIndex,
115        slice: &Bytes,
116        version: Version,
117    ) -> Result<(), ProtoError> {
118        let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap();
119        let value = HeaderValue::from_maybe_shared(slice.slice(idx.value.0..idx.value.1)).unwrap();
120
121        match name {
122            TRANSFER_ENCODING => {
123                if version != Version::HTTP_11 {
124                    return Err(ProtoError::HeaderName);
125                }
126                for val in value.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
127                    let val = val.trim();
128                    if val.eq_ignore_ascii_case("chunked") {
129                        decoder.try_set(TransferCoding::decode_chunked(MAX_HEADERS))?;
130                    }
131                }
132            }
133            CONTENT_LENGTH => {
134                let len = header::parse_content_length(&value)?;
135                decoder.try_set(TransferCoding::length(len))?;
136            }
137            CONNECTION => self.try_set_close_from_header(&value)?,
138            EXPECT => {
139                if !value.as_bytes().eq_ignore_ascii_case(b"100-continue") {
140                    return Err(ProtoError::HeaderValue);
141                }
142                self.set_expect_header()
143            }
144            UPGRADE => {
145                if version != Version::HTTP_11 {
146                    return Err(ProtoError::HeaderName);
147                }
148                decoder.try_set(TransferCoding::upgrade())?;
149            }
150            _ => {}
151        }
152
153        headers.append(name, value);
154
155        Ok(())
156    }
157
158    pub(super) fn try_set_close_from_header(&mut self, val: &HeaderValue) -> Result<(), ProtoError> {
159        for val in val.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
160            let val = val.trim();
161            if val.eq_ignore_ascii_case("keep-alive") {
162                self.remove_close()
163            } else if val.eq_ignore_ascii_case("close") {
164                self.set_close()
165            }
166        }
167        Ok(())
168    }
169}
170
#[cfg(test)]
mod test {
    use super::*;

    /// Decode `head` from a fresh buffer with a 128 byte read limit, expecting a complete and
    /// valid request head.
    fn decode<D>(ctx: &mut Context<'_, D, 4>, head: &str) -> Decoded {
        let mut buf = BytesMut::from(head.as_bytes());
        ctx.decode_head::<128>(&mut buf).unwrap().unwrap()
    }

    /// Assert that the selected body decoder is the chunked decoder.
    fn assert_chunked(decoder: &TransferCoding) {
        assert!(
            matches!(decoder, TransferCoding::DecodeChunked { .. }),
            "transfer coding is not decoded to chunked"
        );
    }

    #[test]
    fn connection_multiple_value() {
        let mut ctx = Context::<_, 4>::new(&());

        // `keep-alive` next to an unrelated token keeps the connection open.
        let _ = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Connection: keep-alive, upgrade\r\n", "\r\n"),
        );
        assert!(!ctx.is_connection_closed());

        // Mixing `keep-alive` and `close` is a malformed header, yet it is not rejected:
        // whichever of the two tokens appears last decides the outcome. There is no particular
        // reason behind this behaviour; this section of the test only guards it against
        // accidental change.
        let _ = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Connection: keep-alive, close, upgrade\r\n", "\r\n"),
        );
        assert!(ctx.is_connection_closed());

        let _ = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Connection: close, keep-alive, upgrade\r\n", "\r\n"),
        );
        assert!(!ctx.is_connection_closed());
    }

    #[test]
    fn transfer_encoding() {
        let mut ctx = Context::<_, 4>::new(&());

        // Two separate headers: both values are kept in order and chunked is selected.
        let (req, decoder) = decode(
            &mut ctx,
            concat!(
                "GET / HTTP/1.1\r\n",
                "Transfer-Encoding: gzip\r\n",
                "Transfer-Encoding: chunked\r\n",
                "\r\n"
            ),
        );
        let mut values = req.headers().get_all(TRANSFER_ENCODING).into_iter();
        assert_eq!(values.next().unwrap().to_str().unwrap(), "gzip");
        assert_eq!(values.next().unwrap().to_str().unwrap(), "chunked");
        assert_chunked(&decoder);

        ctx.reset();

        // The same headers in reverse order still select chunked.
        let (req, decoder) = decode(
            &mut ctx,
            concat!(
                "GET / HTTP/1.1\r\n",
                "Transfer-Encoding: chunked\r\n",
                "Transfer-Encoding: gzip\r\n",
                "\r\n"
            ),
        );
        let mut values = req.headers().get_all(TRANSFER_ENCODING).into_iter();
        assert_eq!(values.next().unwrap().to_str().unwrap(), "chunked");
        assert_eq!(values.next().unwrap().to_str().unwrap(), "gzip");
        assert_chunked(&decoder);

        ctx.reset();

        // One header listing several codings is stored verbatim.
        let (req, decoder) = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Transfer-Encoding: gzip, chunked\r\n", "\r\n"),
        );
        assert_eq!(
            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
            "gzip, chunked"
        );
        assert_chunked(&decoder);

        ctx.reset();

        // chunked on its own.
        let (req, decoder) = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Transfer-Encoding: chunked\r\n", "\r\n"),
        );
        assert_eq!(
            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
            "chunked"
        );
        assert_chunked(&decoder);

        // A coding other than chunked leaves the method's default decoder in place.
        let (_, decoder) = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Transfer-Encoding: identity\r\n", "\r\n"),
        );
        assert!(
            matches!(decoder, TransferCoding::Eof),
            "transfer coding is not decoded to eof"
        );

        // chunked followed by another coding is still decoded as chunked.
        let (_, decoder) = decode(
            &mut ctx,
            concat!("GET / HTTP/1.1\r\n", "Transfer-Encoding: chunked, gzip\r\n", "\r\n"),
        );
        assert_chunked(&decoder);
    }
}