use httparse::Status;
use xitca_unsafe_collection::uninit;
use crate::{
bytes::{Buf, Bytes, BytesMut},
http::{
header::{HeaderMap, HeaderName, HeaderValue, CONNECTION, CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING, UPGRADE},
Extension, Method, Request, RequestExt, Uri, Version,
},
};
use super::{
codec::TransferCoding,
context::Context,
error::ProtoError,
header::{self, HeaderIndex},
};
/// Successful output of `decode_head`: the decoded request together with the
/// `TransferCoding` selected while processing its method and headers.
type Decoded = (Request<RequestExt<()>>, TransferCoding);
impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
/// Tries to decode a single request head from the front of `buf`.
///
/// Returns `Ok(None)` when the parser reports an incomplete head and `buf` still holds
/// fewer than `READ_BUF_LIMIT` bytes, so the caller can read more and retry.
///
/// On a complete head the context is reset, the head bytes are split off `buf`, every
/// header is passed through `try_write_header`, and the request is returned together
/// with the `TransferCoding` selected for it.
///
/// # Errors
///
/// - `ProtoError::HeaderTooLarge` when the head is incomplete and `buf` already holds
///   `READ_BUF_LIMIT` bytes or more.
/// - Whatever the parser, `Method::from_bytes`, the `Uri` parse or `try_write_header`
///   report, converted into `ProtoError`.
pub fn decode_head<const READ_BUF_LIMIT: usize>(
    &mut self,
    buf: &mut BytesMut,
) -> Result<Option<Decoded>, ProtoError> {
    let mut parsed = httparse::Request::new(&mut []);
    let mut raw_headers = uninit::uninit_array::<_, MAX_HEADERS>();

    let head_len = match parsed.parse_with_uninit_headers(buf, &mut raw_headers)? {
        Status::Complete(len) => len,
        Status::Partial => {
            // An unfinished head only becomes an error once the buffer has hit the limit.
            return if buf.remaining() < READ_BUF_LIMIT {
                Ok(None)
            } else {
                Err(ProtoError::HeaderTooLarge)
            };
        }
    };

    // Start from a clean per-request state before interpreting the new head.
    self.reset();

    let method = Method::from_bytes(parsed.method.unwrap().as_bytes())?;
    let uri = parsed.path.unwrap().parse::<Uri>()?;

    // CONNECT starts out as an upgrade; everything else starts as eof and may be
    // replaced while the headers are processed below.
    let mut decoder = if method == Method::CONNECT {
        self.set_connect_method();
        TransferCoding::upgrade()
    } else {
        if method == Method::HEAD {
            self.set_head_method();
        }
        TransferCoding::eof()
    };

    // Any version value other than 1 is handled as HTTP/1.0 and marks the connection closed.
    let version = match parsed.version.unwrap() {
        1 => Version::HTTP_11,
        _ => {
            self.set_close();
            Version::HTTP_10
        }
    };

    // Record header positions and the header count while the parsed view of `buf` is
    // still alive; `split_to` below needs exclusive access to the buffer.
    let mut index_storage = uninit::uninit_array::<_, MAX_HEADERS>();
    let indices = HeaderIndex::record(&mut index_storage, buf, parsed.headers);
    let header_count = parsed.headers.len();

    let head = buf.split_to(head_len).freeze();

    let mut map = self.take_headers();
    map.reserve(header_count);
    for idx in indices {
        self.try_write_header(&mut map, &mut decoder, idx, &head, version)?;
    }

    let ext = Extension::new(*self.socket_addr());
    let extensions = self.take_extensions();

    let mut request = Request::new(RequestExt::from_parts((), ext));
    *request.extensions_mut() = extensions;
    *request.headers_mut() = map;
    *request.uri_mut() = uri;
    *request.version_mut() = version;
    *request.method_mut() = method;

    Ok(Some((request, decoder)))
}
/// Turns one indexed raw header into a typed name/value pair, applies its effect on the
/// connection state and on `decoder`, then appends it to `headers`.
///
/// `idx` carries the start/end byte offsets of the header name and value inside `slice`.
/// The value is created with `Bytes::slice`, so it shares storage with `slice`.
///
/// Special handling:
/// - `Transfer-Encoding`: only accepted on HTTP/1.1; any comma separated entry equal to
///   `chunked` (ASCII case-insensitive) switches `decoder` to chunked decoding.
/// - `Content-Length`: parsed and set on `decoder` as a fixed length.
/// - `Connection`: forwarded to `try_set_close_from_header`.
/// - `Expect`: must be `100-continue` (ASCII case-insensitive).
/// - `Upgrade`: only accepted on HTTP/1.1; switches `decoder` to upgrade.
///
/// # Errors
///
/// - `ProtoError::HeaderName` when the name bytes are rejected by `HeaderName::from_bytes`,
///   or when `Transfer-Encoding` / `Upgrade` appears on a non HTTP/1.1 request.
/// - `ProtoError::HeaderValue` when the value bytes are rejected by
///   `HeaderValue::from_maybe_shared`, when a `Transfer-Encoding` value fails `to_str`,
///   or when `Expect` holds anything other than `100-continue`.
/// - Errors from `decoder.try_set`, `header::parse_content_length` and
///   `try_set_close_from_header` are propagated.
pub fn try_write_header(
    &mut self,
    headers: &mut HeaderMap,
    decoder: &mut TransferCoding,
    idx: &HeaderIndex,
    slice: &Bytes,
    version: Version,
) -> Result<(), ProtoError> {
    // These bytes originate from the peer. Even if the caller's parser has presumably
    // vetted them, the `http` types apply their own rules (e.g. a cap on name length),
    // so a rejection is surfaced as a protocol error rather than a panic.
    let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).map_err(|_| ProtoError::HeaderName)?;
    let value =
        HeaderValue::from_maybe_shared(slice.slice(idx.value.0..idx.value.1)).map_err(|_| ProtoError::HeaderValue)?;

    match name {
        TRANSFER_ENCODING => {
            if version != Version::HTTP_11 {
                return Err(ProtoError::HeaderName);
            }
            for val in value.to_str().map_err(|_| ProtoError::HeaderValue)?.split(',') {
                let val = val.trim();
                if val.eq_ignore_ascii_case("chunked") {
                    decoder.try_set(TransferCoding::decode_chunked())?;
                }
            }
        }
        CONTENT_LENGTH => {
            let len = header::parse_content_length(&value)?;
            decoder.try_set(TransferCoding::length(len))?;
        }
        CONNECTION => self.try_set_close_from_header(&value)?,
        EXPECT => {
            if !value.as_bytes().eq_ignore_ascii_case(b"100-continue") {
                return Err(ProtoError::HeaderValue);
            }
            self.set_expect_header()
        }
        UPGRADE => {
            if version != Version::HTTP_11 {
                return Err(ProtoError::HeaderName);
            }
            decoder.try_set(TransferCoding::upgrade())?;
        }
        _ => {}
    }

    // Every header, special or not, ends up in the map; `append` keeps repeated names.
    headers.append(name, value);
    Ok(())
}
/// Updates the connection close state from a `Connection` header value.
///
/// The value is split on `,` and each trimmed entry is compared ASCII case-insensitively:
/// `close` sets the close state, `keep-alive` removes it, anything else is ignored.
/// Entries are applied left to right, so the last recognised entry decides the outcome.
///
/// # Errors
///
/// Returns `ProtoError::HeaderValue` when `val.to_str()` fails.
pub(super) fn try_set_close_from_header(&mut self, val: &HeaderValue) -> Result<(), ProtoError> {
    let raw = val.to_str().map_err(|_| ProtoError::HeaderValue)?;

    for token in raw.split(',').map(str::trim) {
        if token.eq_ignore_ascii_case("close") {
            self.set_close();
        } else if token.eq_ignore_ascii_case("keep-alive") {
            self.remove_close();
        }
    }

    Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn connection_multiple_value() {
    // Each entry pairs a raw request head with the close state expected once it has been
    // decoded. With several recognised `Connection` entries the last one wins.
    let cases = [
        (
            &b"\
GET / HTTP/1.1\r\n\
Connection: keep-alive, upgrade\r\n\
\r\n\
"[..],
            false,
        ),
        (
            &b"\
GET / HTTP/1.1\r\n\
Connection: keep-alive, close, upgrade\r\n\
\r\n\
"[..],
            true,
        ),
        (
            &b"\
GET / HTTP/1.1\r\n\
Connection: close, keep-alive, upgrade\r\n\
\r\n\
"[..],
            false,
        ),
    ];

    // A single context is shared by all requests, in the order listed above.
    let mut ctx = Context::<_, 4>::new(&());

    for &(head, expect_closed) in cases.iter() {
        let mut buf = BytesMut::from(head);
        let _ = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
        assert_eq!(ctx.is_connection_closed(), expect_closed);
    }
}
#[test]
fn transfer_encoding() {
// Checks which `TransferCoding` `decode_head` returns for different
// `Transfer-Encoding` header layouts. One context is reused for every request.
let mut ctx = Context::<_, 4>::new(&());
// Two separate header lines, `chunked` second: both values are kept in order and
// the decoder is chunked.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: gzip\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
let mut iter = req.headers().get_all(TRANSFER_ENCODING).into_iter();
assert_eq!(iter.next().unwrap().to_str().unwrap(), "gzip");
assert_eq!(iter.next().unwrap().to_str().unwrap(), "chunked");
assert!(
matches!(decoder, TransferCoding::DecodeChunked(..)),
"transfer coding is not decoded to chunked"
);
ctx.reset();
// Two separate header lines, `chunked` first: same outcome with the order reversed.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: chunked\r\n\
Transfer-Encoding: gzip\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
let mut iter = req.headers().get_all(TRANSFER_ENCODING).into_iter();
assert_eq!(iter.next().unwrap().to_str().unwrap(), "chunked");
assert_eq!(iter.next().unwrap().to_str().unwrap(), "gzip");
assert!(
matches!(decoder, TransferCoding::DecodeChunked(..)),
"transfer coding is not decoded to chunked"
);
ctx.reset();
// A single comma separated header line: the value is stored verbatim and the
// decoder is chunked.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: gzip, chunked\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
assert_eq!(
req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
"gzip, chunked"
);
assert!(
matches!(decoder, TransferCoding::DecodeChunked(..)),
"transfer coding is not decoded to chunked"
);
ctx.reset();
// `chunked` as the only coding.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (req, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
assert_eq!(
req.headers().get(TRANSFER_ENCODING).unwrap().to_str().unwrap(),
"chunked"
);
assert!(
matches!(decoder, TransferCoding::DecodeChunked(..)),
"transfer coding is not decoded to chunked"
);
// No explicit reset from here on; `decode_head` resets the context itself once a
// complete head has been parsed.
// A coding other than `chunked` leaves the decoder at its initial eof state.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: identity\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (_, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
assert!(
matches!(decoder, TransferCoding::Eof),
"transfer coding is not decoded to eof"
);
// `chunked` listed before another coding in the same line still selects chunked.
let head = b"\
GET / HTTP/1.1\r\n\
Transfer-Encoding: chunked, gzip\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);
let (_, decoder) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();
assert!(
matches!(decoder, TransferCoding::DecodeChunked(..)),
"transfer coding is not decoded to chunked"
);
}
}