use std::convert::Infallible;
use std::fmt::{self, Debug};
use headers::HeaderValue;
use http::StatusCode;
use http_body_util::{BodyExt, Full};
use http_error::HttpError;
use hyper::body::{Buf, Bytes, Incoming};
use mime::Mime;
pub use multer;
use multer::Constraints;
pub use multer::Multipart;
use serde::de::DeserializeOwned;
use crate::IntoResponse;
const DEFAULT_LIMIT: usize = 1024 * 1024; pub struct Body<B: http_body::Body<Data = Bytes, Error = E> = Incoming, E = hyper::Error>(
BodyState<B, E>,
);
#[derive(Default)]
enum BodyState<B: http_body::Body<Data = Bytes, Error = E> = Incoming, E = hyper::Error> {
#[default]
Empty,
Read,
Unread(BodyData<B, E>),
}
struct BodyData<B: http_body::Body<Data = Bytes, Error = E> = Incoming, E = hyper::Error> {
body: B,
content_type: Option<HeaderValue>,
len: Option<usize>,
limit: usize,
}
impl<B: http_body::Body<Data = Bytes, Error = E>, E> Body<B, E>
where
    BodyError: From<E>,
{
    /// Wraps an unread body.
    ///
    /// `len` is the declared body length, checked against the size limit before any data is
    /// read. `content_type` is the raw `Content-Type` header. The limit starts at
    /// [`DEFAULT_LIMIT`] and can be changed with [`Body::with_limit`].
    pub(crate) fn new(body: B, len: Option<usize>, content_type: Option<HeaderValue>) -> Self {
        Body(BodyState::Unread(BodyData {
            body,
            content_type,
            len,
            limit: DEFAULT_LIMIT,
        }))
    }

    /// Creates a body with no data. Every read method on it fails with [`BodyError::Empty`].
    pub fn empty() -> Self {
        Body(BodyState::Empty)
    }

    /// Sets the largest declared body length, in bytes, that the read methods accept.
    ///
    /// Has no effect on a body that is empty or has already been read.
    pub fn with_limit(&mut self, limit: usize) -> &mut Self {
        if let BodyState::Unread(inner) = &mut self.0 {
            inner.limit = limit;
        }
        self
    }

    /// Takes the raw body out, together with its parsed `Content-Type`.
    ///
    /// A `Content-Type` header that is missing, not visible ASCII, or not a valid media type
    /// yields `None`. A non-empty body is marked as read by this call even when it returns an
    /// error, so any later read fails with [`BodyError::AlreadyRead`].
    ///
    /// NOTE(review): the limit is only compared with the *declared* length. The body is later
    /// collected without a cap, so this relies on `B` honouring that length (presumably
    /// enforced by hyper for `Incoming` — confirm for other body types).
    ///
    /// # Errors
    ///
    /// - [`BodyError::Empty`] if the body has no data.
    /// - [`BodyError::AlreadyRead`] if the body was already taken.
    /// - [`BodyError::ContentLengthMissing`] if no length was declared.
    /// - [`BodyError::MaxSize`] if the declared length exceeds the limit.
    pub fn take(&mut self) -> Result<(B, Option<Mime>), BodyError> {
        let data = match std::mem::take(&mut self.0) {
            // `mem::take` already left `Empty` in place, which is the correct state.
            BodyState::Empty => return Err(BodyError::Empty),
            BodyState::Read => {
                self.0 = BodyState::Read;
                return Err(BodyError::AlreadyRead);
            }
            BodyState::Unread(data) => {
                // The body counts as consumed even if the checks below fail.
                self.0 = BodyState::Read;
                data
            }
        };

        let len = data.len.ok_or(BodyError::ContentLengthMissing)?;
        if len > data.limit {
            return Err(BodyError::MaxSize);
        }

        let mime = data
            .content_type
            .and_then(|value| value.to_str().ok()?.parse().ok());
        Ok((data.body, mime))
    }

    /// Reads the whole body and deserializes it as JSON.
    ///
    /// A `Content-Type` header is optional. When present, its type/subtype must be
    /// `application/json`; parameters such as `charset=utf-8` are ignored.
    ///
    /// # Errors
    ///
    /// - Any error from [`Body::take`].
    /// - [`BodyError::WrongContentType`] for a different media type.
    /// - The converted body error if reading fails.
    /// - [`BodyError::Json`] if deserialization fails.
    pub async fn json<T: DeserializeOwned>(&mut self) -> Result<T, BodyError> {
        let (body, mime_type) = self.take()?;
        if let Some(mime_type) = mime_type {
            // Compare only type/subtype: `Mime` equality also compares parameters, so
            // `application/json; charset=utf-8` would otherwise be rejected.
            if mime_type.type_() != mime::APPLICATION || mime_type.subtype() != mime::JSON {
                return Err(BodyError::WrongContentType("application/json"));
            }
        }
        let whole_body = body.collect().await?.aggregate();
        Ok(serde_json::from_reader(whole_body.reader())?)
    }

    /// Reads the whole body and deserializes it as `application/x-www-form-urlencoded` data.
    ///
    /// A `Content-Type` header is optional. When present, its type/subtype must be
    /// `application/x-www-form-urlencoded`; parameters such as `charset=UTF-8` are ignored.
    ///
    /// # Errors
    ///
    /// - Any error from [`Body::take`].
    /// - [`BodyError::WrongContentType`] for a different media type.
    /// - The converted body error if reading fails.
    /// - [`BodyError::Form`] if deserialization fails.
    pub async fn form<T: DeserializeOwned>(&mut self) -> Result<T, BodyError> {
        let (body, mime_type) = self.take()?;
        if let Some(mime_type) = mime_type {
            // Same parameter-insensitive comparison as in `json`.
            if mime_type.type_() != mime::APPLICATION
                || mime_type.subtype() != mime::WWW_FORM_URLENCODED
            {
                return Err(BodyError::WrongContentType(
                    "application/x-www-form-urlencoded",
                ));
            }
        }
        let whole_body = body.collect().await?.aggregate();
        Ok(serde_urlencoded::from_reader(whole_body.reader())?)
    }

    /// Reads the whole body into a buffer, regardless of its `Content-Type`.
    ///
    /// # Errors
    ///
    /// Any error from [`Body::take`], or the converted body error if reading fails.
    pub async fn bytes(&mut self) -> Result<impl Buf, BodyError> {
        let (body, _) = self.take()?;
        let data = body.collect().await?.aggregate();
        Ok(data)
    }

    /// Parses the body as `multipart/form-data` using default [`Constraints`].
    ///
    /// # Errors
    ///
    /// See [`Body::multipart_with_constraints`].
    pub async fn multipart(&mut self) -> Result<Multipart<'_>, BodyError>
    where
        B: Send + Unpin,
        E: std::error::Error + Send + Sync + 'static,
    {
        self.multipart_with_constraints(Default::default()).await
    }

    /// Parses the body as `multipart/form-data`, enforcing the given multer [`Constraints`].
    ///
    /// The boundary is taken from the `Content-Type` header. Non-data frames (for example
    /// trailers) end the part stream.
    ///
    /// # Errors
    ///
    /// Any error from [`Body::take`], or [`BodyError::WrongContentType`] if the header is
    /// missing or carries no multipart boundary.
    pub async fn multipart_with_constraints(
        &mut self,
        constraints: Constraints,
    ) -> Result<Multipart<'_>, BodyError>
    where
        B: Send + Unpin,
        E: std::error::Error + Send + Sync + 'static,
    {
        let (body, mime) = self.take()?;
        let boundary = mime
            .and_then(|mime| multer::parse_boundary(mime).ok())
            .ok_or(BodyError::WrongContentType("multipart/form-data"))?;
        let body = futures_util::stream::try_unfold(body, |mut body| async move {
            let Some(bytes) = body.frame().await else {
                return Ok::<_, E>(None);
            };
            match bytes?.into_data() {
                Ok(data) => Ok(Some((data, body))),
                Err(_) => Ok(None),
            }
        });
        Ok(Multipart::with_constraints(body, boundary, constraints))
    }
}
#[derive(Debug, thiserror::Error)]
pub enum BodyError {
#[error("the requested exceeded the max accepted body size")]
MaxSize,
#[error("content-length header is required to safely read a body")]
ContentLengthMissing,
#[error(transparent)]
Hyper(#[from] hyper::Error),
#[error("error deserializing body as JSON")]
Json(#[from] serde_json::Error),
#[error("error deserializing body as application/x-www-form-urlencoded")]
Form(#[from] serde_urlencoded::de::Error),
#[error("body is empty")]
Empty,
#[error("body has already been read")]
AlreadyRead,
#[error("received wrong content type, expected: {0}")]
WrongContentType(&'static str),
}
impl HttpError for BodyError {
    /// Maps each error to the HTTP status sent back to the client.
    ///
    /// Problems with what the client sent map to 4xx: an oversized, empty or undeclared-length
    /// body, a wrong content type, or a body that does not deserialize. Transport errors and
    /// double reads map to 500.
    fn status_code(&self) -> StatusCode {
        match self {
            BodyError::MaxSize => StatusCode::PAYLOAD_TOO_LARGE,
            // Malformed JSON / form data is a client mistake, not a server fault.
            BodyError::Empty
            | BodyError::ContentLengthMissing
            | BodyError::WrongContentType(_)
            | BodyError::Json(_)
            | BodyError::Form(_) => StatusCode::BAD_REQUEST,
            BodyError::Hyper(_) | BodyError::AlreadyRead => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Writes a human-readable reason for the error.
    ///
    /// Deserialization errors get a fixed, generic message. Every other variant uses its
    /// `Display` text, i.e. its `#[error(...)]` message.
    fn reason(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BodyError::Json(_) => f.write_str("invalid JSON body"),
            BodyError::Form(_) => f.write_str("invalid form body"),
            // Call `Display` explicitly. A bare `err.fmt(f)` resolves to `Debug::fmt` here
            // (only `Debug` is imported), which would emit `MaxSize` instead of the message.
            err => fmt::Display::fmt(err, f),
        }
    }
}
impl IntoResponse for BodyError {
fn into_response(self) -> crate::Response {
self.status_code().into_response()
}
}
impl From<Infallible> for BodyError {
    /// Lets in-memory bodies (whose error type is `Infallible`) satisfy `BodyError: From<E>`.
    ///
    /// `Infallible` has no values, so this can never be called. The empty `match` lets the
    /// compiler prove that, instead of relying on a runtime `unreachable!()` panic.
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
impl From<&'static str> for Body<Full<Bytes>, Infallible> {
    /// Builds an in-memory body from a static string without copying it.
    ///
    /// The declared length is the string's byte length, and no content type is set.
    fn from(data: &'static str) -> Self {
        let bytes = Bytes::from_static(data.as_bytes());
        let len = bytes.len();
        Body::new(Full::new(bytes), Some(len), None)
    }
}
impl From<String> for Body<Full<Bytes>, Infallible> {
    /// Builds an in-memory body that takes ownership of the string's buffer.
    ///
    /// The declared length is the string's byte length, and no content type is set.
    fn from(data: String) -> Self {
        let bytes = Bytes::from(data);
        let declared_len = Some(bytes.len());
        Self::new(Full::new(bytes), declared_len, None)
    }
}
impl From<&'static [u8]> for Body<Full<Bytes>, Infallible> {
    /// Builds an in-memory body from a static byte slice without copying it.
    ///
    /// The declared length is the slice length, and no content type is set.
    fn from(data: &'static [u8]) -> Self {
        let len = data.len();
        let full = Full::new(Bytes::from_static(data));
        Self::new(full, Some(len), None)
    }
}
impl From<Vec<u8>> for Body<Full<Bytes>, Infallible> {
    /// Builds an in-memory body that takes ownership of the vector's buffer.
    ///
    /// The declared length is the vector length, and no content type is set.
    fn from(data: Vec<u8>) -> Self {
        let bytes = Bytes::from(data);
        let len = bytes.len();
        Self::new(Full::new(bytes), Some(len), None)
    }
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use http_body_util::Full;

    use super::*;

    #[tokio::test]
    async fn test_json_content_length_missing() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), None, None);
        assert!(
            matches!(
                body.json::<i64>().await,
                Err(BodyError::ContentLengthMissing)
            ),
            "expected Err(BodyError::ContentLengthMissing)"
        );
    }

    #[tokio::test]
    async fn test_bytes_content_length_missing() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), None, None);
        assert!(
            matches!(body.bytes().await, Err(BodyError::ContentLengthMissing)),
            "expected Err(BodyError::ContentLengthMissing)"
        );
    }

    #[tokio::test]
    async fn test_json_max_size() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        let body = body.with_limit(1);
        assert!(
            matches!(body.json::<i64>().await, Err(BodyError::MaxSize)),
            "expected Err(BodyError::MaxSize)"
        );
    }

    #[tokio::test]
    async fn test_bytes_max_size() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        let body = body.with_limit(1);
        assert!(
            matches!(body.bytes().await, Err(BodyError::MaxSize)),
            "expected Err(BodyError::MaxSize)"
        );
    }

    /// A body whose declared length equals the limit is accepted (the check is `>`).
    #[tokio::test]
    async fn test_json_at_limit() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        let body = body.with_limit(2);
        assert_eq!(body.json::<i64>().await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_json() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        assert_eq!(body.json::<i64>().await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_json_wrong_content_type() {
        let mut body = Body::new(
            Full::new(Bytes::from_static(b"42")),
            Some(2),
            Some(HeaderValue::from_static("text/plain")),
        );
        assert!(
            matches!(
                body.json::<i64>().await,
                Err(BodyError::WrongContentType("application/json"))
            ),
            "expected Err(BodyError::WrongContentType(\"application/json\"))"
        );
    }

    #[tokio::test]
    async fn test_form() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"a=1")), Some(3), None);
        let pairs: Vec<(String, String)> = body.form().await.unwrap();
        assert_eq!(pairs, vec![("a".to_owned(), "1".to_owned())]);
    }

    #[tokio::test]
    async fn test_bytes() {
        use std::io::Read;
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        let mut reader = body.bytes().await.unwrap().reader();
        let mut dst = [0; 8];
        let n = reader.read(&mut dst).unwrap();
        assert_eq!(&dst[..n], b"42");
    }

    #[tokio::test]
    async fn test_json_already_read() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        body.json::<i64>().await.unwrap();
        assert!(
            matches!(body.json::<i64>().await, Err(BodyError::AlreadyRead)),
            "expected Err(BodyError::AlreadyRead)"
        );
    }

    #[tokio::test]
    async fn test_bytes_already_read() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        body.bytes().await.unwrap();
        assert!(
            matches!(body.bytes().await, Err(BodyError::AlreadyRead)),
            "expected Err(BodyError::AlreadyRead)"
        );
    }

    #[tokio::test]
    async fn test_json_empty() {
        let mut body = Body::<Full<Bytes>, Infallible>::empty();
        assert!(
            matches!(body.json::<i64>().await, Err(BodyError::Empty)),
            "expected Err(BodyError::Empty)"
        );
    }

    #[tokio::test]
    async fn test_bytes_empty() {
        let mut body = Body::<Full<Bytes>, Infallible>::empty();
        assert!(
            matches!(body.bytes().await, Err(BodyError::Empty)),
            "expected Err(BodyError::Empty)"
        );
    }

    #[tokio::test]
    async fn test_json_error() {
        let mut body = Body::new(Full::new(Bytes::from_static(b"42")), Some(2), None);
        assert!(
            matches!(body.json::<String>().await, Err(BodyError::Json(_))),
            "expected Err(BodyError::Json(_))"
        );
    }
}