1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
//! size limiter middleware

use async_trait::async_trait;
use salvo_core::http::errors::*;
use salvo_core::http::HttpBody;
use salvo_core::prelude::*;
use salvo_core::routing::FlowCtrl;

/// Middleware that rejects requests whose body is known to exceed a byte limit.
///
/// The limit (the wrapped `u64`) is compared against the upper bound reported by
/// the body's [`HttpBody::size_hint`]. If that bound is larger than the limit, the
/// response gets a `PayloadTooLarge` HTTP error and `ctrl.skip_reset()` is called.
/// In every other case the request is forwarded with `ctrl.call_next`.
///
/// Requests without a body, or whose size hint has no upper bound, are forwarded
/// unchecked: this handler does not count bytes while the body is being read.
///
/// NOTE(review): `skip_reset` is assumed to stop the remaining handlers in this
/// salvo version (later versions name it `skip_rest`) — confirm.
#[derive(Debug)]
pub struct MaxSizeHandler(u64);

#[async_trait]
impl Handler for MaxSizeHandler {
    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
        // `None` when there is no body or its size hint has no upper bound.
        let upper = req.body().and_then(|body| body.size_hint().upper());
        if matches!(upper, Some(len) if len > self.0) {
            res.set_http_error(PayloadTooLarge());
            ctrl.skip_reset();
        } else {
            // Forward explicitly in every non-rejected case, including the
            // no-body / unbounded-body case, instead of relying on any implicit
            // continuation of the handler chain.
            ctrl.call_next(req, depot, res).await;
        }
    }
}
/// Builds a [`MaxSizeHandler`] whose limit is `size`.
///
/// The limit is compared against the upper bound of a request body's size hint.
pub fn max_size(size: u64) -> MaxSizeHandler {
    let limit = size;
    MaxSizeHandler(limit)
}

#[cfg(test)]
mod tests {
    use salvo_core::hyper;
    use salvo_core::prelude::*;

    use super::*;

    #[fn_handler]
    async fn hello() -> &'static str {
        "hello"
    }

    /// Builds a POST request to `/hello` whose body is `payload`.
    fn post_hello(payload: &'static str) -> Request {
        hyper::Request::builder()
            .method("POST")
            .uri("http://127.0.0.1:7979/hello")
            .body(payload.into())
            .unwrap()
            .into()
    }

    #[tokio::test]
    async fn test_size_limiter() {
        let service = Service::new(
            Router::new()
                .hoop(MaxSizeHandler(32))
                .push(Router::with_path("hello").post(hello)),
        );

        // A body under the 32-byte limit reaches the `hello` handler.
        let body_text = service.handle(post_hello("abc")).await.take_text().await.unwrap();
        assert_eq!(body_text, "hello");

        // A body over the limit is rejected before routing to `hello`.
        let rejected = service
            .handle(post_hello(
                "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
            ))
            .await;
        assert_eq!(rejected.status_code().unwrap(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}