volo_http/server/layer/body_limit.rs

1use http::StatusCode;
2use http_body::Body;
3use motore::{Service, layer::Layer};
4
5use crate::{context::ServerContext, request::Request, response::Response, server::IntoResponse};
6
/// [`Layer`] for limiting body size
///
/// Requests whose size exceeds the configured limit are answered with `413 Payload Too Large`
/// without reaching the inner service. The size is the one announced by the `Content-Length`
/// header or by the body's size hint. The bytes actually streamed from the body are not
/// counted.
///
/// See [`BodyLimitLayer::new`] for more details.
#[derive(Clone, Debug)]
pub struct BodyLimitLayer {
    /// Maximum accepted body size, in bytes.
    limit: usize,
}
14
15impl BodyLimitLayer {
16    /// Create a new [`BodyLimitLayer`] with given `body_limit`.
17    ///
18    /// If the Body is larger than the `body_limit`, the request will be rejected.
19    pub fn new(body_limit: usize) -> Self {
20        Self { limit: body_limit }
21    }
22}
23
24impl<S> Layer<S> for BodyLimitLayer {
25    type Service = BodyLimitService<S>;
26
27    fn layer(self, inner: S) -> Self::Service {
28        BodyLimitService {
29            service: inner,
30            limit: self.limit,
31        }
32    }
33}
34
/// [`BodyLimitLayer`] generated [`Service`]
///
/// Checks the declared size of each request body against `limit` before forwarding the
/// request to the wrapped service.
///
/// See [`BodyLimitLayer`] for more details.
#[derive(Clone, Debug)]
pub struct BodyLimitService<S> {
    /// The wrapped inner service.
    service: S,
    /// Maximum accepted body size, in bytes.
    limit: usize,
}
42
43impl<S, B> Service<ServerContext, Request<B>> for BodyLimitService<S>
44where
45    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
46    S::Response: IntoResponse,
47    B: Body + Send,
48{
49    type Response = Response;
50    type Error = S::Error;
51
52    async fn call(
53        &self,
54        cx: &mut ServerContext,
55        req: Request<B>,
56    ) -> Result<Self::Response, Self::Error> {
57        let (parts, body) = req.into_parts();
58        // get body size from content length
59        if let Some(size) = parts
60            .headers
61            .get(http::header::CONTENT_LENGTH)
62            .and_then(|v| v.to_str().ok().and_then(|s| s.parse::<usize>().ok()))
63        {
64            if size > self.limit {
65                return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
66            }
67        } else {
68            // get body size from stream
69            if body.size_hint().lower() > self.limit as u64 {
70                return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
71            }
72        }
73
74        let req = Request::from_parts(parts, body);
75        Ok(self.service.call(cx, req).await?.into_response())
76    }
77}
78
#[cfg(test)]
mod tests {
    use http::{HeaderValue, Method, StatusCode, header::CONTENT_LENGTH};
    use motore::{Service, layer::Layer};

    use crate::{
        server::{
            layer::BodyLimitLayer,
            route::{Route, any},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    #[tokio::test]
    async fn test_body_limit() {
        async fn handler() -> &'static str {
            "Hello, World"
        }

        let body_limit_layer = BodyLimitLayer::new(8);
        let route: Route<_> = Route::new(any(handler));
        let service = body_limit_layer.layer(route);

        let mut cx = empty_cx();

        // Test case 1: reject, the body itself is larger than the limit
        let req = simple_req(Method::GET, "/", "111111111".to_string());
        let res = service.call(&mut cx, req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);

        // Test case 2: not reject
        let req = simple_req(Method::GET, "/", "1".to_string());
        let res = service.call(&mut cx, req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        // Test case 3: not reject, a body of exactly the limit is still accepted
        let req = simple_req(Method::GET, "/", "11111111".to_string());
        let res = service.call(&mut cx, req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        // Test case 4: reject, `Content-Length` declares more than the limit
        let mut req = simple_req(Method::POST, "/", String::new());
        req.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("9"));
        let res = service.call(&mut cx, req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);

        // Test case 5: not reject, `Content-Length` within the limit
        let mut req = simple_req(Method::POST, "/", "1".to_string());
        req.headers_mut()
            .insert(CONTENT_LENGTH, HeaderValue::from_static("1"));
        let res = service.call(&mut cx, req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }
}
115}