volo_http/server/layer/
body_limit.rs1use http::StatusCode;
2use http_body::Body;
3use motore::{Service, layer::Layer};
4
5use crate::{context::ServerContext, request::Request, response::Response, server::IntoResponse};
6
/// [`Layer`] that rejects requests whose body is known to be larger than a configured limit.
///
/// See [`BodyLimitLayer::new`] for how the limit is supplied.
#[derive(Clone, Debug)]
pub struct BodyLimitLayer {
    /// Largest accepted body size; compared against the `Content-Length` header and the body's
    /// size hint by the service this layer produces.
    limit: usize,
}
14
15impl BodyLimitLayer {
16 pub fn new(body_limit: usize) -> Self {
20 Self { limit: body_limit }
21 }
22}
23
24impl<S> Layer<S> for BodyLimitLayer {
25 type Service = BodyLimitService<S>;
26
27 fn layer(self, inner: S) -> Self::Service {
28 BodyLimitService {
29 service: inner,
30 limit: self.limit,
31 }
32 }
33}
34
/// [`Service`] produced by [`BodyLimitLayer`].
///
/// It answers `413 Payload Too Large` for requests whose body is known to exceed `limit`, and
/// forwards every other request to the wrapped service.
#[derive(Clone, Debug)]
pub struct BodyLimitService<S> {
    /// The wrapped service that receives requests passing the size check.
    service: S,
    /// Largest accepted body size, copied from the [`BodyLimitLayer`] that built this service.
    limit: usize,
}
42
43impl<S, B> Service<ServerContext, Request<B>> for BodyLimitService<S>
44where
45 S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
46 S::Response: IntoResponse,
47 B: Body + Send,
48{
49 type Response = Response;
50 type Error = S::Error;
51
52 async fn call(
53 &self,
54 cx: &mut ServerContext,
55 req: Request<B>,
56 ) -> Result<Self::Response, Self::Error> {
57 let (parts, body) = req.into_parts();
58 if let Some(size) = parts
60 .headers
61 .get(http::header::CONTENT_LENGTH)
62 .and_then(|v| v.to_str().ok().and_then(|s| s.parse::<usize>().ok()))
63 {
64 if size > self.limit {
65 return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
66 }
67 } else {
68 if body.size_hint().lower() > self.limit as u64 {
70 return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
71 }
72 }
73
74 let req = Request::from_parts(parts, body);
75 Ok(self.service.call(cx, req).await?.into_response())
76 }
77}
78
#[cfg(test)]
mod tests {
    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        server::{
            layer::BodyLimitLayer,
            route::{Route, any},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// Bodies above the limit are rejected with 413; bodies at or below it reach the handler.
    #[tokio::test]
    async fn test_body_limit() {
        async fn handler() -> &'static str {
            "Hello, World"
        }

        // Limit of 8 bytes in front of a handler that always succeeds.
        let route: Route<_> = Route::new(any(handler));
        let service = BodyLimitLayer::new(8).layer(route);

        let mut cx = empty_cx();

        let cases = [
            // One byte over the limit.
            ("111111111", StatusCode::PAYLOAD_TOO_LARGE),
            // Exactly at the limit: the check is strictly "greater than".
            ("11111111", StatusCode::OK),
            // Well under the limit.
            ("1", StatusCode::OK),
        ];

        for (payload, expected) in cases {
            let req = simple_req(Method::GET, "/", payload.to_string());
            let res = service.call(&mut cx, req).await.unwrap();
            assert_eq!(
                res.status(),
                expected,
                "unexpected status for a {}-byte body",
                payload.len()
            );
        }
    }
}