1use core::mem::MaybeUninit;
2
3use httparse::Status;
4
5use crate::{
6 bytes::{Buf, Bytes, BytesMut},
7 http::{
8 Extension, Method, Request, RequestExt, Uri, Version,
9 header::{CONNECTION, CONTENT_LENGTH, EXPECT, HeaderMap, HeaderName, HeaderValue, TRANSFER_ENCODING, UPGRADE},
10 },
11};
12
13use super::{
14 codec::TransferCoding,
15 context::Context,
16 error::ProtoError,
17 header::{self, HeaderIndex},
18};
19
20type Decoded = (Request<RequestExt<()>>, TransferCoding);
21
22impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
23 pub fn decode_head<const READ_BUF_LIMIT: usize>(
25 &mut self,
26 buf: &mut BytesMut,
27 ) -> Result<Option<Decoded>, ProtoError> {
28 let mut req = httparse::Request::new(&mut []);
29 let mut headers = [const { MaybeUninit::uninit() }; MAX_HEADERS];
30
31 match req.parse_with_uninit_headers(buf, &mut headers)? {
32 Status::Complete(len) => {
33 self.reset();
35
36 let method = Method::from_bytes(req.method.unwrap().as_bytes())?;
37
38 let mut decoder = match method {
40 Method::CONNECT => {
42 self.set_connect_method();
43 TransferCoding::upgrade()
44 }
45 Method::HEAD => {
46 self.set_head_method();
47 TransferCoding::eof()
48 }
49 _ => TransferCoding::eof(),
50 };
51
52 let version = if req.version.unwrap() == 1 {
54 Version::HTTP_11
56 } else {
57 self.set_close();
58 Version::HTTP_10
59 };
60
61 let mut header_idx = [const { MaybeUninit::uninit() }; MAX_HEADERS];
63 let header_idx_slice = HeaderIndex::record(&mut header_idx, buf, req.headers);
64 let headers_len = req.headers.len();
65
66 let path = req.path.unwrap();
68 let path_head = path.as_ptr() as usize - buf.as_ptr() as usize;
69 let path_len = path.len();
70
71 let slice = buf.split_to(len).freeze();
73
74 let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
75
76 let mut headers = self.take_headers();
78 headers.reserve(headers_len);
79
80 for idx in header_idx_slice {
82 self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?;
83 }
84
85 let ext = Extension::new(*self.socket_addr());
86 let mut req = Request::new(RequestExt::from_parts((), ext));
87
88 let extensions = self.take_extensions();
89
90 *req.method_mut() = method;
91 *req.version_mut() = version;
92 *req.uri_mut() = uri;
93 *req.headers_mut() = headers;
94 *req.extensions_mut() = extensions;
95
96 Ok(Some((req, decoder)))
97 }
98
99 Status::Partial => {
100 if buf.remaining() >= READ_BUF_LIMIT {
101 Err(ProtoError::HeaderTooLarge)
102 } else {
103 Ok(None)
104 }
105 }
106 }
107 }
108
109 pub fn try_write_header(
110 &mut self,
111 headers: &mut HeaderMap,
112 decoder: &mut TransferCoding,
113 idx: &HeaderIndex,
114 slice: &Bytes,
115 version: Version,
116 ) -> Result<(), ProtoError> {
117 let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap();
118 let value = HeaderValue::from_maybe_shared(slice.slice(idx.value.0..idx.value.1)).unwrap();
119
120 match name {
121 TRANSFER_ENCODING => {
122 if version != Version::HTTP_11 {
123 return Err(ProtoError::HeaderName);
124 }
125 for val in value.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
126 let val = val.trim();
127 if val.eq_ignore_ascii_case("chunked") {
128 decoder.try_set(TransferCoding::decode_chunked())?;
129 }
130 }
131 }
132 CONTENT_LENGTH => {
133 let len = header::parse_content_length(&value)?;
134 decoder.try_set(TransferCoding::length(len))?;
135 }
136 CONNECTION => self.try_set_close_from_header(&value)?,
137 EXPECT => {
138 if !value.as_bytes().eq_ignore_ascii_case(b"100-continue") {
139 return Err(ProtoError::HeaderValue);
140 }
141 self.set_expect_header()
142 }
143 UPGRADE => {
144 if version != Version::HTTP_11 {
145 return Err(ProtoError::HeaderName);
146 }
147 decoder.try_set(TransferCoding::upgrade())?;
148 }
149 _ => {}
150 }
151
152 headers.append(name, value);
153
154 Ok(())
155 }
156
157 pub(super) fn try_set_close_from_header(&mut self, val: &HeaderValue) -> Result<(), ProtoError> {
158 for val in val.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
159 let val = val.trim();
160 if val.eq_ignore_ascii_case("keep-alive") {
161 self.remove_close()
162 } else if val.eq_ignore_ascii_case("close") {
163 self.set_close()
164 }
165 }
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod test {
    use super::*;

    /// Decode `head` as one request head with a 128 byte read limit.
    ///
    /// Panics when decoding fails or when the head is reported as partial.
    #[track_caller]
    fn decode<D>(ctx: &mut Context<'_, D, 4>, head: &[u8]) -> Decoded {
        let mut buf = BytesMut::from(head);
        ctx.decode_head::<128>(&mut buf).unwrap().unwrap()
    }

    /// Assert that `decoder` is the chunked body decoder.
    #[track_caller]
    fn assert_chunked(decoder: &TransferCoding) {
        assert!(
            matches!(decoder, TransferCoding::DecodeChunked(..)),
            "transfer coding is not decoded to chunked"
        );
    }

    /// Assert that `req` carries the given `Transfer-Encoding` values first, in order.
    #[track_caller]
    fn assert_transfer_encodings(req: &Request<RequestExt<()>>, expected: &[&str]) {
        let mut values = req.headers().get_all(TRANSFER_ENCODING).into_iter();
        for want in expected {
            assert_eq!(values.next().unwrap().to_str().unwrap(), *want);
        }
    }

    /// `Connection` values holding several tokens: the close state after
    /// decoding follows the order in which `keep-alive` and `close` appear.
    #[test]
    fn connection_multiple_value() {
        let mut ctx = Context::<_, 4>::new(&());

        // Only keep-alive is present: connection stays open.
        let _ = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Connection: keep-alive, upgrade\r\n\
            \r\n\
            ",
        );
        assert!(!ctx.is_connection_closed());

        // close comes after keep-alive: connection is closed.
        let _ = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Connection: keep-alive, close, upgrade\r\n\
            \r\n\
            ",
        );
        assert!(ctx.is_connection_closed());

        // keep-alive comes after close: connection stays open.
        let _ = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Connection: close, keep-alive, upgrade\r\n\
            \r\n\
            ",
        );
        assert!(!ctx.is_connection_closed());
    }

    /// `Transfer-Encoding` handling: any `chunked` token selects the chunked
    /// decoder, wherever it appears; without it the decoder stays at eof.
    #[test]
    fn transfer_encoding() {
        let mut ctx = Context::<_, 4>::new(&());

        // chunked in a second header line.
        let (req, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: gzip\r\n\
            Transfer-Encoding: chunked\r\n\
            \r\n\
            ",
        );
        assert_transfer_encodings(&req, &["gzip", "chunked"]);
        assert_chunked(&decoder);

        ctx.reset();

        // chunked in the first header line.
        let (req, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: chunked\r\n\
            Transfer-Encoding: gzip\r\n\
            \r\n\
            ",
        );
        assert_transfer_encodings(&req, &["chunked", "gzip"]);
        assert_chunked(&decoder);

        ctx.reset();

        // chunked as the last token of a single header line.
        let (req, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: gzip, chunked\r\n\
            \r\n\
            ",
        );
        assert_eq!(
            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
            "gzip, chunked"
        );
        assert_chunked(&decoder);

        ctx.reset();

        // chunked alone.
        let (req, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: chunked\r\n\
            \r\n\
            ",
        );
        assert_eq!(
            req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
            "chunked"
        );
        assert_chunked(&decoder);

        // No chunked token at all: decoder stays at eof.
        let (_, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: identity\r\n\
            \r\n\
            ",
        );
        assert!(
            matches!(decoder, TransferCoding::Eof),
            "transfer coding is not decoded to eof"
        );

        // chunked as the first token of a single header line.
        let (_, decoder) = decode(
            &mut ctx,
            b"\
            GET / HTTP/1.1\r\n\
            Transfer-Encoding: chunked, gzip\r\n\
            \r\n\
            ",
        );
        assert_chunked(&decoder);
    }
}