1use core::mem::MaybeUninit;
2
3use httparse::Status;
4
5use crate::{
6 bytes::{Buf, Bytes, BytesMut},
7 http::{
8 Extension, Method, Request, RequestExt, Uri, Version,
9 header::{CONNECTION, CONTENT_LENGTH, EXPECT, HeaderMap, HeaderName, HeaderValue, TRANSFER_ENCODING, UPGRADE},
10 },
11};
12
13use super::{
14 context::Context,
15 error::ProtoError,
16 header::{self, HeaderIndex},
17 trasnder_coding::TransferCoding,
18};
19
/// Output of [`Context::decode_head`] for a complete request head: the request (carrying `()`
/// as its `RequestExt` payload) paired with the [`TransferCoding`] chosen for reading its body.
type Decoded = (Request<RequestExt<()>>, TransferCoding);
21
22impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
23 pub fn decode_head<const READ_BUF_LIMIT: usize>(
25 &mut self,
26 buf: &mut BytesMut,
27 ) -> Result<Option<Decoded>, ProtoError> {
28 let mut req = httparse::Request::new(&mut []);
29 let mut headers = [const { MaybeUninit::uninit() }; MAX_HEADERS];
30
31 match req.parse_with_uninit_headers(buf, &mut headers)? {
32 Status::Complete(len) => {
33 self.reset();
35
36 let method = Method::from_bytes(req.method.unwrap().as_bytes())?;
37
38 let mut decoder = match method {
40 Method::CONNECT => {
42 self.set_connect_method();
43 TransferCoding::upgrade()
44 }
45 Method::HEAD => {
46 self.set_head_method();
47 TransferCoding::eof()
48 }
49 _ => TransferCoding::eof(),
50 };
51
52 let version = if req.version.unwrap() == 1 {
54 Version::HTTP_11
56 } else {
57 self.set_http10();
58 self.set_close();
59 Version::HTTP_10
60 };
61
62 let mut header_idx = [const { MaybeUninit::uninit() }; MAX_HEADERS];
64 let header_idx_slice = HeaderIndex::record(&mut header_idx, buf, req.headers);
65 let headers_len = req.headers.len();
66
67 let path = req.path.unwrap();
69 let path_head = path.as_ptr() as usize - buf.as_ptr() as usize;
70 let path_len = path.len();
71
72 let slice = buf.split_to(len).freeze();
74
75 let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
76
77 let mut headers = self.take_headers();
79 headers.reserve(headers_len);
80
81 for idx in header_idx_slice {
83 self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?;
84 }
85
86 let ext = Extension::new(*self.socket_addr());
87 let mut req = Request::new(RequestExt::from_parts((), ext));
88
89 let extensions = self.take_extensions();
90
91 *req.method_mut() = method;
92 *req.version_mut() = version;
93 *req.uri_mut() = uri;
94 *req.headers_mut() = headers;
95 *req.extensions_mut() = extensions;
96
97 Ok(Some((req, decoder)))
98 }
99
100 Status::Partial => {
101 if buf.remaining() >= READ_BUF_LIMIT {
102 Err(ProtoError::HeaderTooLarge)
103 } else {
104 Ok(None)
105 }
106 }
107 }
108 }
109
110 pub fn try_write_header(
111 &mut self,
112 headers: &mut HeaderMap,
113 decoder: &mut TransferCoding,
114 idx: &HeaderIndex,
115 slice: &Bytes,
116 version: Version,
117 ) -> Result<(), ProtoError> {
118 let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap();
119 let value = HeaderValue::from_maybe_shared(slice.slice(idx.value.0..idx.value.1)).unwrap();
120
121 match name {
122 TRANSFER_ENCODING => {
123 if version != Version::HTTP_11 {
124 return Err(ProtoError::HeaderName);
125 }
126 for val in value.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
127 let val = val.trim();
128 if val.eq_ignore_ascii_case("chunked") {
129 decoder.try_set(TransferCoding::decode_chunked(MAX_HEADERS))?;
130 }
131 }
132 }
133 CONTENT_LENGTH => {
134 let len = header::parse_content_length(&value)?;
135 decoder.try_set(TransferCoding::length(len))?;
136 }
137 CONNECTION => self.try_set_close_from_header(&value)?,
138 EXPECT => {
139 if !value.as_bytes().eq_ignore_ascii_case(b"100-continue") {
140 return Err(ProtoError::HeaderValue);
141 }
142 self.set_expect_header()
143 }
144 UPGRADE => {
145 if version != Version::HTTP_11 {
146 return Err(ProtoError::HeaderName);
147 }
148 decoder.try_set(TransferCoding::upgrade())?;
149 }
150 _ => {}
151 }
152
153 headers.append(name, value);
154
155 Ok(())
156 }
157
158 pub(super) fn try_set_close_from_header(&mut self, val: &HeaderValue) -> Result<(), ProtoError> {
159 for val in val.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
160 let val = val.trim();
161 if val.eq_ignore_ascii_case("keep-alive") {
162 self.remove_close()
163 } else if val.eq_ignore_ascii_case("close") {
164 self.set_close()
165 }
166 }
167 Ok(())
168 }
169}
170
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `head` through `decode_head` with a 128 byte read limit.
    ///
    /// Panics if the head is rejected or incomplete.
    fn decode<D>(ctx: &mut Context<'_, D, 4>, head: &[u8]) -> Decoded {
        let mut buf = BytesMut::from(head);
        ctx.decode_head::<128>(&mut buf).unwrap().unwrap()
    }

    /// Asserts that the body decoder selected for a request is the chunked one.
    fn assert_chunked(decoder: &TransferCoding) {
        assert!(
            matches!(decoder, TransferCoding::DecodeChunked { .. }),
            "transfer coding is not decoded to chunked"
        );
    }

    /// Asserts that the leading `Transfer-Encoding` values of `req` equal `expected`, in order.
    fn assert_transfer_encoding(req: &Request<RequestExt<()>>, expected: &[&str]) {
        let mut values = req.headers().get_all(TRANSFER_ENCODING).into_iter();
        for want in expected {
            assert_eq!(values.next().unwrap().to_str().unwrap(), *want);
        }
    }

    #[test]
    fn connection_multiple_value() {
        let mut ctx = Context::<_, 4>::new(&());

        // Each head is paired with whether the connection must end up closed.
        // Mixing keep-alive and close is malformed, but rather than rejecting it the last
        // recognised token decides.
        let cases: [(&[u8], bool); 3] = [
            (b"GET / HTTP/1.1\r\nConnection: keep-alive, upgrade\r\n\r\n", false),
            (b"GET / HTTP/1.1\r\nConnection: keep-alive, close, upgrade\r\n\r\n", true),
            (b"GET / HTTP/1.1\r\nConnection: close, keep-alive, upgrade\r\n\r\n", false),
        ];

        for &(head, expect_closed) in cases.iter() {
            let _ = decode(&mut ctx, head);
            assert_eq!(ctx.is_connection_closed(), expect_closed);
        }
    }

    #[test]
    fn transfer_encoding() {
        let mut ctx = Context::<_, 4>::new(&());

        // Heads that must select the chunked decoder, paired with the leading
        // Transfer-Encoding values expected in the decoded header map.
        let chunked_cases: [(&[u8], &[&str]); 4] = [
            (
                b"GET / HTTP/1.1\r\nTransfer-Encoding: gzip\r\nTransfer-Encoding: chunked\r\n\r\n",
                &["gzip", "chunked"],
            ),
            (
                b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nTransfer-Encoding: gzip\r\n\r\n",
                &["chunked", "gzip"],
            ),
            (b"GET / HTTP/1.1\r\nTransfer-Encoding: gzip, chunked\r\n\r\n", &["gzip, chunked"]),
            (b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n", &["chunked"]),
        ];

        for (i, &(head, expected)) in chunked_cases.iter().enumerate() {
            if i > 0 {
                ctx.reset();
            }
            let (req, decoder) = decode(&mut ctx, head);
            assert_transfer_encoding(&req, expected);
            assert_chunked(&decoder);
        }

        // `identity` alone does not request chunked framing, so the decoder stays at eof.
        let (_, decoder) = decode(&mut ctx, b"GET / HTTP/1.1\r\nTransfer-Encoding: identity\r\n\r\n");
        assert!(
            matches!(decoder, TransferCoding::Eof),
            "transfer coding is not decoded to eof"
        );

        // chunked is honoured even when it is not the final coding listed.
        let (_, decoder) = decode(&mut ctx, b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked, gzip\r\n\r\n");
        assert_chunked(&decoder);
    }
}