use crate::{copy, http_config::DEFAULT_CONFIG, Body, HttpConfig, MutCow};
use encoding_rs::Encoding;
use futures_lite::{ready, AsyncRead, AsyncReadExt, AsyncWrite, Stream};
use httparse::{InvalidChunkSize, Status};
use std::{
fmt::{self, Debug, Formatter},
future::{Future, IntoFuture},
io::{self, ErrorKind},
iter,
pin::Pin,
task::{Context, Poll},
};
use Poll::{Pending, Ready};
use ReceivedBodyState::{Chunked, End, FixedLength, PartialChunkSize, Start};
#[cfg(test)]
mod tests;
/// Forwards to `log::trace!`, prefixing the message with `:<line> `, where the
/// line number is that of the `trace!` invocation (`line!()` reports the
/// outermost macro call site).
///
/// Takes a literal format string followed by at least one format argument.
macro_rules! trace {
    ($fmt:literal, $($rest:tt)+) => {
        log::trace!(concat!(":{} ", $fmt), line!(), $($rest)+);
    };
}
/// An incoming message body read from `Transport` and decoded incrementally.
///
/// Both `Content-Length` (fixed-length) and chunked framing are handled by a
/// small state machine ([`ReceivedBodyState`]). The body can be consumed via
/// [`AsyncRead`], as a [`Stream`] of byte vectors, all at once with
/// `read_bytes` / `read_string`, by awaiting it directly, or discarded with
/// `drain`.
pub struct ReceivedBody<'conn, Transport> {
    /// Declared body length, if any; `None` makes the body decode as chunked.
    content_length: Option<u64>,
    /// Bytes already read from the transport but not yet consumed. `read_raw`
    /// serves these before touching the transport, and bytes left over by a
    /// decoding step (e.g. a partial chunk-size line, or data after the final
    /// chunk) are stored here.
    buffer: MutCow<'conn, Option<Vec<u8>>>,
    /// The underlying connection; `None` once taken by `take_transport` or
    /// handed to `on_completion`.
    transport: Option<MutCow<'conn, Transport>>,
    /// Current position in the decoding state machine.
    state: MutCow<'conn, ReceivedBodyState>,
    /// Invoked once with the owned transport when the body reaches `End`, but
    /// only if this body owns (rather than borrows) the transport.
    on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
    /// Character encoding used by `read_string`.
    encoding: &'static Encoding,
    /// Largest body, in bytes, this reader accepts.
    max_len: u64,
    /// Initial capacity of the `read_bytes` buffer when the length is unknown.
    initial_len: usize,
    /// Maximum state-machine steps per `poll_read` before yielding; also
    /// passed to `copy` by `drain`.
    copy_loops_per_yield: usize,
    /// Upper bound on the `read_bytes` preallocation when the length is known.
    max_preallocate: usize,
}
/// Returns the tail of `buf` starting at byte offset `min`, or `None` if that
/// tail would be empty.
///
/// An offset at or past the end of `buf`, or one that does not fit in a
/// `usize`, yields `None`.
fn slice_from(min: u64, buf: &[u8]) -> Option<&[u8]> {
    let start = usize::try_from(min).unwrap_or(usize::MAX);
    match buf.get(start..) {
        Some(tail) if !tail.is_empty() => Some(tail),
        _ => None,
    }
}
impl<'conn, Transport> ReceivedBody<'conn, Transport>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    /// Builds a body using the limits from the crate's default `HttpConfig`.
    #[allow(missing_docs)]
    #[doc(hidden)]
    pub fn new(
        content_length: Option<u64>,
        buffer: impl Into<MutCow<'conn, Option<Vec<u8>>>>,
        transport: impl Into<MutCow<'conn, Transport>>,
        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
        on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
        encoding: &'static Encoding,
    ) -> Self {
        Self::new_with_config(
            content_length,
            buffer,
            transport,
            state,
            on_completion,
            encoding,
            &DEFAULT_CONFIG,
        )
    }

    /// Builds a body whose length limit, preallocation sizes, and loops per
    /// yield are taken from `config`.
    #[allow(missing_docs)]
    #[doc(hidden)]
    pub(crate) fn new_with_config(
        content_length: Option<u64>,
        buffer: impl Into<MutCow<'conn, Option<Vec<u8>>>>,
        transport: impl Into<MutCow<'conn, Transport>>,
        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
        on_completion: Option<Box<dyn Fn(Transport) + Send + Sync + 'static>>,
        encoding: &'static Encoding,
        config: &HttpConfig,
    ) -> Self {
        Self {
            content_length,
            buffer: buffer.into(),
            transport: Some(transport.into()),
            state: state.into(),
            on_completion,
            encoding,
            max_len: config.received_body_max_len,
            initial_len: config.received_body_initial_len,
            copy_loops_per_yield: config.copy_loops_per_yield,
            max_preallocate: config.received_body_max_preallocate,
        }
    }

    /// Returns the declared length of this body, if known.
    ///
    /// A body without a declared length is decoded as chunked.
    pub fn content_length(&self) -> Option<u64> {
        self.content_length
    }

    /// Reads the whole body and decodes it to a `String` using
    /// [`ReceivedBody::encoding`].
    ///
    /// Decoding goes through `encoding_rs::Encoding::decode`, so malformed
    /// sequences are replaced rather than reported.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`ReceivedBody::read_bytes`].
    pub async fn read_string(self) -> crate::Result<String> {
        let encoding = self.encoding();
        let bytes = self.read_bytes().await?;
        let (decoded, _, _) = encoding.decode(&bytes);
        // `into_owned` reuses the allocation when decoding already produced an
        // owned `String`; `to_string` would always copy.
        Ok(decoded.into_owned())
    }

    /// Whether this body holds its transport by value (as opposed to
    /// borrowing it, or having already given it away).
    fn owns_transport(&self) -> bool {
        self.transport
            .as_ref()
            .map(MutCow::is_owned)
            .unwrap_or_default()
    }

    /// Sets the maximum number of body bytes this reader will accept.
    pub fn set_max_len(&mut self, max_len: u64) {
        self.max_len = max_len;
    }

    /// Chainable form of [`ReceivedBody::set_max_len`].
    #[must_use]
    pub fn with_max_len(mut self, max_len: u64) -> Self {
        self.set_max_len(max_len);
        self
    }

    /// Reads the whole body into a `Vec<u8>`.
    ///
    /// With a known content length, the vector is preallocated with at most
    /// `max_preallocate` bytes. Otherwise it starts with `initial_len` bytes
    /// of capacity.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::ReceivedBodyTooLong`] if the declared content
    /// length exceeds the maximum length or does not fit in a `usize`. Also
    /// propagates any error raised while reading the body.
    pub async fn read_bytes(mut self) -> crate::Result<Vec<u8>> {
        let mut vec = if let Some(len) = self.content_length {
            if len > self.max_len {
                return Err(crate::Error::ReceivedBodyTooLong(self.max_len));
            }

            let len = usize::try_from(len)
                .map_err(|_| crate::Error::ReceivedBodyTooLong(self.max_len))?;

            Vec::with_capacity(len.min(self.max_preallocate))
        } else {
            Vec::with_capacity(self.initial_len)
        };

        self.read_to_end(&mut vec).await?;
        Ok(vec)
    }

    /// Returns the encoding that [`ReceivedBody::read_string`] uses.
    pub fn encoding(&self) -> &'static Encoding {
        self.encoding
    }

    /// Reads raw (still framed) bytes into `buf`, serving pending buffered
    /// bytes before the transport. Fails with `NotConnected` if the transport
    /// has been taken.
    fn read_raw(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        if let Some(transport) = self.transport.as_deref_mut() {
            read_raw(&mut self.buffer, transport, cx, buf)
        } else {
            Ready(Err(ErrorKind::NotConnected.into()))
        }
    }

    /// Reads and discards the rest of the body.
    ///
    /// Returns whatever the crate's `copy` helper reports, presumably the
    /// number of bytes discarded.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while reading the body.
    pub async fn drain(self) -> io::Result<u64> {
        let copy_loops_per_yield = self.copy_loops_per_yield;
        copy(self, futures_lite::io::sink(), copy_loops_per_yield).await
    }
}
/// Awaiting a `ReceivedBody` directly reads it to completion as a `String`,
/// exactly like calling [`ReceivedBody::read_string`].
impl<'a, Transport> IntoFuture for ReceivedBody<'a, Transport>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    type Output = crate::Result<String>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        let read = self.read_string();
        Box::pin(read)
    }
}
impl<T> ReceivedBody<'static, T> {
pub fn take_transport(&mut self) -> Option<T> {
self.transport.take().map(MutCow::unwrap_owned)
}
}
/// Fills `buf` with raw bytes, first from `opt_buffer` and then from
/// `transport`.
///
/// - If the pending buffer holds more bytes than `buf` can take, `buf` is
///   filled completely and the excess stays buffered, in order.
/// - Otherwise all buffered bytes are moved into `buf` and `opt_buffer` is
///   cleared. If space remains, the transport is polled once for more.
///
/// Once buffered bytes have been delivered, the result is always
/// `Ready(Ok(n))` with `n` counting them. A `Pending` or an error from the
/// follow-up transport read is not surfaced here, because that would discard
/// bytes already copied into `buf`. An error is presumably reported again by
/// the next read that reaches the transport.
fn read_raw<Transport>(
    opt_buffer: &mut Option<Vec<u8>>,
    transport: &mut Transport,
    cx: &mut Context<'_>,
    buf: &mut [u8],
) -> Poll<io::Result<usize>>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    match opt_buffer {
        Some(buffer) if !buffer.is_empty() => {
            let len = buffer.len();
            if len > buf.len() {
                trace!(
                    "have {} bytes of pending data but can only use {}",
                    len,
                    buf.len()
                );
                let used = buf.len();
                buf.copy_from_slice(&buffer[..used]);
                // Shift the rest down in place instead of allocating a new Vec.
                buffer.drain(..used);
                Ready(Ok(used))
            } else {
                trace!("have {} bytes of pending data, using all of it", len);
                buf[..len].copy_from_slice(buffer);
                *opt_buffer = None;

                if len == buf.len() {
                    // No room left; polling the transport could not add anything.
                    return Ready(Ok(len));
                }

                match Pin::new(transport).poll_read(cx, &mut buf[len..]) {
                    Ready(Ok(additional)) => Ready(Ok(len + additional)),
                    Pending => Ready(Ok(len)),
                    Ready(Err(error)) => {
                        trace!(
                            "deferring transport error {:?}; returning {} buffered bytes",
                            error,
                            len
                        );
                        Ready(Ok(len))
                    }
                }
            }
        }

        _ => Pin::new(transport).poll_read(cx, buf),
    }
}
/// Decodes chunked transfer-encoding framing in place.
///
/// Parameters:
/// - `buf`: bytes freshly read from the wire.
/// - `remaining`: bytes left in the chunk that was in progress when the
///   previous read ended, including the two-byte CRLF after the chunk data.
/// - `total`: payload bytes decoded so far.
/// - `max_len`: the largest payload this body accepts.
///
/// Payload bytes are compacted to the front of `buf` and framing is removed.
///
/// Returns `(next_state, payload_len, unused)`:
/// - `payload_len` is the number of payload bytes now at the start of `buf`.
/// - `unused` holds bytes that must be retained for later. That is either a
///   partial chunk-size line (`PartialChunkSize`), or the bytes after the
///   zero-size chunk line and the two bytes that follow it (`End`). Those two
///   bytes are assumed to be the final CRLF; trailer fields are not parsed.
///
/// # Errors
///
/// - `ConnectionAborted` if `buf` is empty (presumably the transport hit EOF
///   mid-body).
/// - `Unsupported` ("content too long") if the decoded total exceeds `max_len`.
/// - `InvalidData` if a chunk size is malformed or overflows.
fn chunk_decode(
    remaining: u64,
    mut total: u64,
    buf: &mut [u8],
    max_len: u64,
) -> io::Result<(ReceivedBodyState, usize, Option<Vec<u8>>)> {
    if buf.is_empty() {
        return Err(io::Error::from(ErrorKind::ConnectionAborted));
    }

    // Payload ranges to compact to the front of `buf` after parsing.
    let mut ranges_to_keep = vec![];
    let mut chunk_start = 0u64;
    // Offset (relative to `buf`) just past the current chunk's trailing CRLF.
    let mut chunk_end = remaining;

    let (request_body_state, unused) = loop {
        // `chunk_end <= 2` means at most the CRLF remains: no payload to keep.
        if chunk_end > 2 {
            let keep_start = usize::try_from(chunk_start).unwrap_or(usize::MAX);
            let keep_end = buf
                .len()
                .min(usize::try_from(chunk_end - 2).unwrap_or(usize::MAX));
            ranges_to_keep.push(keep_start..keep_end);
            let new_bytes = (keep_end - keep_start) as u64;
            total += new_bytes;
            if total > max_len {
                return Err(io::Error::new(ErrorKind::Unsupported, "content too long"));
            }
        }

        chunk_start = chunk_end;

        // `slice_from` yields `None` for an empty tail too, so this also covers
        // a buffer that ends exactly on a chunk boundary (`remaining` = 0).
        let Some(buf_to_read) = slice_from(chunk_start, buf) else {
            break (
                Chunked {
                    remaining: chunk_start - buf.len() as u64,
                    total,
                },
                None,
            );
        };

        match httparse::parse_chunk_size(buf_to_read) {
            Ok(Status::Complete((framing_bytes, chunk_size))) => {
                chunk_start += framing_bytes as u64;
                chunk_end = (2 + chunk_start)
                    .checked_add(chunk_size)
                    .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "chunk size too long"))?;

                if chunk_size == 0 {
                    break (End, slice_from(chunk_end, buf).map(Vec::from));
                }
            }

            Ok(Status::Partial) => {
                break (PartialChunkSize { total }, Some(Vec::from(buf_to_read)));
            }

            Err(InvalidChunkSize) => {
                return Err(io::Error::new(ErrorKind::InvalidData, "invalid chunk size"));
            }
        }
    };

    // Ranges are ascending and non-overlapping, and each destination starts at
    // or before its source, so no copy clobbers a range not yet moved.
    let mut bytes = 0;
    for range_to_keep in ranges_to_keep {
        let new_bytes = bytes + range_to_keep.end - range_to_keep.start;
        buf.copy_within(range_to_keep, bytes);
        bytes = new_bytes;
    }

    Ok((request_body_state, bytes, unused))
}
/// Size of the first read attempted by the [`Stream`] impl, and the amount of
/// spare space kept after the collected bytes for each subsequent read.
const STREAM_READ_BUF_LENGTH: usize = 128;

/// Yields the body as a sequence of byte vectors.
///
/// Each item gathers as many bytes as are immediately available, growing its
/// buffer while reads keep succeeding. The stream ends when the body does. A
/// read error is logged and also ends the stream; bytes gathered earlier in
/// that same poll are discarded.
impl<'conn, Transport> Stream for ReceivedBody<'conn, Transport>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    type Item = Vec<u8>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut chunk = vec![0; STREAM_READ_BUF_LENGTH];
        let mut filled = 0;

        loop {
            let (read, is_pending) = match Pin::new(&mut *self).poll_read(cx, &mut chunk[filled..]) {
                Ready(Ok(read)) => (read, false),
                Pending => (0, true),
                Ready(Err(error)) => {
                    log::error!("got {error:?} in ReceivedBody stream");
                    return Ready(None);
                }
            };

            if read > 0 {
                filled += read;
                // Grow by exactly what was read, so there is always
                // STREAM_READ_BUF_LENGTH bytes of free space after `filled`.
                chunk.extend(iter::repeat(0).take(read));
            } else if filled > 0 {
                chunk.truncate(filled);
                return Ready(Some(chunk));
            } else if is_pending {
                return Pending;
            } else {
                return Ready(None);
            }
        }
    }
}
type StateOutput = Poll<io::Result<(ReceivedBodyState, usize, Option<Vec<u8>>)>>;
impl<'conn, Transport> ReceivedBody<'conn, Transport>
where
Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
#[inline]
fn handle_start(&mut self) -> StateOutput {
Ready(Ok((
match self.content_length {
Some(0) => End,
Some(total_length) if total_length < self.max_len => FixedLength {
current_index: 0,
total: total_length,
},
Some(_) => {
return Ready(Err(io::Error::new(
ErrorKind::Unsupported,
"content too long",
)))
}
None => Chunked {
remaining: 0,
total: 0,
},
},
0,
None,
)))
}
#[inline]
fn handle_chunked(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
remaining: u64,
total: u64,
) -> StateOutput {
let bytes = ready!(self.read_raw(cx, buf)?);
Ready(chunk_decode(
remaining,
total,
&mut buf[..bytes],
self.max_len,
))
}
#[inline]
fn handle_partial(&mut self, cx: &mut Context<'_>, buf: &mut [u8], total: u64) -> StateOutput {
let transport = self
.transport
.as_deref_mut()
.ok_or_else(|| io::Error::from(ErrorKind::NotConnected))?;
let bytes = ready!(Pin::new(transport).poll_read(cx, buf))?;
if bytes == 0 {
return Ready(Err(io::Error::from(ErrorKind::ConnectionAborted)));
}
let mut inner_buf = self.buffer.take().unwrap_or_default();
inner_buf.extend_from_slice(&buf[..bytes]);
match httparse::parse_chunk_size(&inner_buf) {
Ok(Status::Complete((framing_bytes, 0))) => {
Ready(Ok((End, 0, Some(Vec::from(&inner_buf[framing_bytes..])))))
}
Ok(Status::Complete((framing_bytes, remaining))) => Ready(Ok((
Chunked {
remaining: remaining + 2,
total,
},
0,
Some(Vec::from(&inner_buf[framing_bytes..])),
))),
Ok(Status::Partial) => Ready(Ok((PartialChunkSize { total }, 0, Some(inner_buf)))),
Err(InvalidChunkSize) => Ready(Err(io::Error::new(
ErrorKind::InvalidData,
"invalid chunk framing",
))),
}
}
#[inline]
fn handle_fixed_length(
&mut self,
cx: &mut Context<'_>,
buf: &mut [u8],
current_index: u64,
total_length: u64,
) -> StateOutput {
let len = buf.len();
let remaining = usize::try_from(total_length - current_index).unwrap_or(usize::MAX);
let buf = &mut buf[..len.min(remaining)];
let bytes = ready!(self.read_raw(cx, buf)?);
let current_index = current_index + bytes as u64;
let state = if bytes == 0 || current_index == total_length {
End
} else {
FixedLength {
current_index,
total: total_length,
}
};
Ready(Ok((state, bytes, None)))
}
}
/// Reads decoded body bytes by driving the [`ReceivedBodyState`] machine.
///
/// Each poll runs at most `copy_loops_per_yield` steps. If none of them
/// produced bytes or reached the end, the task is woken immediately and
/// `Pending` is returned, so other tasks get a chance to run. When the body
/// reaches `End` and owns its transport, the transport is passed to the
/// `on_completion` callback (at most once).
impl<'conn, Transport> AsyncRead for ReceivedBody<'conn, Transport>
where
    Transport: AsyncRead + Unpin + Send + Sync + 'static,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // An empty buffer cannot receive bytes. Without this guard, the
        // fixed-length step would see a zero-byte read and mark the body as
        // finished, and the chunked step would fail with ConnectionAborted.
        if buf.is_empty() {
            return Ready(Ok(0));
        }

        for _ in 0..self.copy_loops_per_yield {
            trace!("polling received body with state {:?}", &*self.state);
            let ret = match *self.state {
                Start => self.handle_start(),
                Chunked { remaining, total } => self.handle_chunked(cx, buf, remaining, total),
                PartialChunkSize { total } => self.handle_partial(cx, buf, total),
                FixedLength {
                    current_index,
                    total: total_length,
                } => self.handle_fixed_length(cx, buf, current_index, total_length),
                End => Ready(Ok((End, 0, None))),
            };

            let (new_body_state, bytes, unused) = ready!(ret)?;

            if let Some(mut unused) = unused {
                // `unused` comes from the bytes this step just consumed, and
                // anything still pending in the buffer (e.g. the tail that
                // `read_raw` could not fit into `buf`) comes after it on the
                // wire. It therefore goes in front to keep the stream in order.
                if let Some(existing) = self.buffer.take() {
                    unused.extend_from_slice(&existing);
                }
                *self.buffer = Some(unused);
            }

            *self.state = new_body_state;

            if *self.state == End {
                if self.owns_transport() {
                    if let Some(on_completion) = self.on_completion.take() {
                        if let Some(transport) = self.transport.take() {
                            on_completion(transport.unwrap_owned());
                        }
                    }
                }
                return Ready(Ok(bytes));
            } else if bytes != 0 {
                return Ready(Ok(bytes));
            }
        }

        cx.waker().wake_by_ref();
        Pending
    }
}
impl<'conn, Transport> Debug for ReceivedBody<'conn, Transport> {
    /// Shows the following; the transport itself is not shown:
    /// - the decoding state
    /// - the declared length
    /// - any pending buffered bytes, rendered lossily as UTF-8
    /// - whether a completion callback is still pending
    ///
    /// NOTE(review): the struct is reported under the name `RequestBody`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let pending = self.buffer.as_deref().map(String::from_utf8_lossy);
        let mut out = f.debug_struct("RequestBody");
        out.field("state", &*self.state);
        out.field("content_length", &self.content_length);
        out.field("buffer", &pending);
        out.field("on_completion", &self.on_completion.is_some());
        out.finish()
    }
}
/// Position of a [`ReceivedBody`] in its decoding state machine.
///
/// `ReceivedBody` stores this behind a `MutCow`, so it may presumably be
/// owned by the body or borrowed from its creator — confirm against callers.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
#[allow(missing_docs)]
#[doc(hidden)]
pub enum ReceivedBodyState {
    /// Nothing read yet; the first read picks the framing from the declared
    /// content length.
    #[default]
    Start,
    /// Decoding a chunked body.
    Chunked {
        /// Bytes left in the current chunk as of the previous read, including
        /// the CRLF that ends the chunk data.
        remaining: u64,
        /// Payload bytes decoded so far.
        total: u64,
    },
    /// A chunk-size line was split across reads; its start is held in the
    /// body's pending buffer.
    PartialChunkSize {
        /// Payload bytes decoded so far.
        total: u64,
    },
    /// Reading a body of declared length.
    FixedLength {
        /// Body bytes read so far.
        current_index: u64,
        /// Declared body length.
        total: u64,
    },
    /// No further body bytes will be produced.
    End,
}
impl<Transport> From<ReceivedBody<'static, Transport>> for Body
where
    Transport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
    /// Wraps a received body as a streaming [`Body`], carrying over its
    /// declared content length.
    fn from(rb: ReceivedBody<'static, Transport>) -> Self {
        let content_length = rb.content_length();
        Body::new_streaming(rb, content_length)
    }
}