viz_core/middleware/
limits.rs

1//! Limits Middleware.
2
3#[cfg(feature = "multipart")]
4use std::sync::Arc;
5
6use crate::{Handler, IntoResponse, Request, Response, Result, Transform, types};
7
8/// A configuration for [`LimitsMiddleware`].
9#[derive(Clone, Debug)]
10pub struct Config {
11    limits: types::Limits,
12    #[cfg(feature = "multipart")]
13    multipart: Arc<types::MultipartLimits>,
14}
15
16impl Config {
17    /// Creates a new Config.
18    #[must_use]
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    /// Sets a limits for the Text/Bytes/Form.
24    #[must_use]
25    pub fn limits(mut self, limits: types::Limits) -> Self {
26        self.limits = limits.sort();
27        self
28    }
29
30    /// Sets a limits for the Multipart Form.
31    #[cfg(feature = "multipart")]
32    #[must_use]
33    pub fn multipart(mut self, limits: types::MultipartLimits) -> Self {
34        *Arc::make_mut(&mut self.multipart) = limits;
35        self
36    }
37}
38
39impl Default for Config {
40    fn default() -> Self {
41        Self {
42            limits: types::Limits::default(),
43            #[cfg(feature = "multipart")]
44            multipart: Arc::new(types::MultipartLimits::default()),
45        }
46    }
47}
48
49impl<H> Transform<H> for Config
50where
51    H: Clone,
52{
53    type Output = LimitsMiddleware<H>;
54
55    fn transform(&self, h: H) -> Self::Output {
56        LimitsMiddleware {
57            h,
58            config: self.clone(),
59        }
60    }
61}
62
63/// Limits middleware.
64#[derive(Clone, Debug)]
65pub struct LimitsMiddleware<H> {
66    h: H,
67    config: Config,
68}
69
70#[crate::async_trait]
71impl<H, O> Handler<Request> for LimitsMiddleware<H>
72where
73    H: Handler<Request, Output = Result<O>>,
74    O: IntoResponse,
75{
76    type Output = Result<Response>;
77
78    async fn call(&self, mut req: Request) -> Self::Output {
79        req.extensions_mut().insert(self.config.limits.clone());
80        #[cfg(feature = "multipart")]
81        req.extensions_mut().insert(self.config.multipart.clone());
82
83        self.h.call(req).await.map(IntoResponse::into_response)
84    }
85}