Skip to main content

tower_request_guard/
guard.rs

1use crate::violation::OnViolation;
2use std::time::Duration;
3
/// Fully resolved guard settings, produced by `RequestGuardBuilder::build`
/// and never mutated afterwards.
///
/// Each check is optional: `None` (or an empty header list) means the
/// corresponding setter was never called on the builder.
#[derive(Clone, Debug)]
pub struct GuardConfig {
    /// Largest accepted request body, in bytes.
    pub(crate) max_body_size: Option<u64>,
    /// Per-request timeout.
    pub(crate) timeout: Option<Duration>,
    /// Accepted Content-Type media types (e.g. `"application/json"`), if restricted.
    pub(crate) allowed_content_types: Option<Vec<String>>,
    /// Header names that every request must carry, in registration order.
    pub(crate) required_headers: Vec<String>,
    /// Deepest allowed JSON nesting level (only with the `json` feature).
    #[cfg(feature = "json")]
    pub(crate) max_json_depth: Option<u32>,
}
14
15/// The built guard holding config and violation policy.
16#[derive(Debug, Clone)]
17pub struct RequestGuard {
18    pub(crate) config: GuardConfig,
19    pub(crate) on_violation: OnViolation,
20}
21
22impl RequestGuard {
23    /// Create a new builder for configuring a request guard.
24    pub fn builder() -> RequestGuardBuilder {
25        RequestGuardBuilder::default()
26    }
27
28    /// Create a Tower layer from this guard.
29    pub fn layer(self) -> crate::layer::RequestGuardLayer {
30        crate::layer::RequestGuardLayer::new(self)
31    }
32}
33
34/// Builder for RequestGuard.
35#[derive(Default)]
36pub struct RequestGuardBuilder {
37    max_body_size: Option<u64>,
38    timeout: Option<Duration>,
39    allowed_content_types: Option<Vec<String>>,
40    required_headers: Vec<String>,
41    on_violation: OnViolation,
42    #[cfg(feature = "json")]
43    max_json_depth: Option<u32>,
44}
45
46impl RequestGuardBuilder {
47    /// Set the maximum allowed body size in bytes.
48    pub fn max_body_size(mut self, size: u64) -> Self {
49        self.max_body_size = Some(size);
50        self
51    }
52
53    /// Set the per-request timeout duration.
54    pub fn timeout(mut self, duration: Duration) -> Self {
55        self.timeout = Some(duration);
56        self
57    }
58
59    /// Set the allowed Content-Type media types (e.g. `"application/json"`).
60    pub fn allowed_content_types<I, S>(mut self, types: I) -> Self
61    where
62        I: IntoIterator<Item = S>,
63        S: Into<String>,
64    {
65        self.allowed_content_types = Some(types.into_iter().map(Into::into).collect());
66        self
67    }
68
69    /// Require a header to be present on every request.
70    pub fn require_header(mut self, name: impl Into<String>) -> Self {
71        self.required_headers.push(name.into());
72        self
73    }
74
75    /// Set the violation handling policy (default: [`OnViolation::Reject`]).
76    pub fn on_violation(mut self, policy: OnViolation) -> Self {
77        self.on_violation = policy;
78        self
79    }
80
81    /// Set the maximum allowed JSON nesting depth (requires `json` feature).
82    #[cfg(feature = "json")]
83    pub fn max_json_depth(mut self, depth: u32) -> Self {
84        self.max_json_depth = Some(depth);
85        self
86    }
87
88    /// Build the [`RequestGuard`] with the configured settings.
89    pub fn build(self) -> RequestGuard {
90        RequestGuard {
91            config: GuardConfig {
92                max_body_size: self.max_body_size,
93                timeout: self.timeout,
94                allowed_content_types: self.allowed_content_types,
95                required_headers: self.required_headers,
96                #[cfg(feature = "json")]
97                max_json_depth: self.max_json_depth,
98            },
99            on_violation: self.on_violation,
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A freshly built guard has every check unset and rejects on violation.
    #[test]
    fn builder_defaults() {
        let guard = RequestGuard::builder().build();
        assert_eq!(guard.config.max_body_size, None);
        assert_eq!(guard.config.timeout, None);
        assert!(guard.config.allowed_content_types.is_none());
        assert!(guard.config.required_headers.is_empty());
        assert!(matches!(guard.on_violation, OnViolation::Reject));
    }

    /// Every setter lands in the resolved configuration.
    #[test]
    fn builder_sets_all_fields() {
        let guard = RequestGuard::builder()
            .max_body_size(1_048_576)
            .timeout(Duration::from_secs(30))
            .allowed_content_types(["application/json"])
            .require_header("Authorization")
            .require_header("X-Request-Id")
            .on_violation(OnViolation::LogAndPass)
            .build();

        assert_eq!(guard.config.max_body_size, Some(1_048_576));
        assert_eq!(guard.config.timeout, Some(Duration::from_secs(30)));
        assert_eq!(
            guard.config.allowed_content_types,
            Some(vec!["application/json".to_string()])
        );
        assert_eq!(
            guard.config.required_headers,
            vec!["Authorization".to_string(), "X-Request-Id".to_string()]
        );
        assert!(matches!(guard.on_violation, OnViolation::LogAndPass));
    }

    /// Chained setters are applied, and setters that were not called leave
    /// their defaults untouched.
    #[test]
    fn builder_chaining() {
        let guard = RequestGuard::builder()
            .max_body_size(1024)
            .timeout(Duration::from_secs(5))
            .require_header("Auth")
            .build();

        assert_eq!(guard.config.max_body_size, Some(1024));
        assert_eq!(guard.config.timeout, Some(Duration::from_secs(5)));
        assert_eq!(guard.config.required_headers, vec!["Auth".to_string()]);
        assert!(guard.config.allowed_content_types.is_none());
        assert!(matches!(guard.on_violation, OnViolation::Reject));
    }

    /// `allowed_content_types` replaces (rather than extends) an earlier list
    /// and accepts owned `String`s as well as `&str`.
    #[test]
    fn allowed_content_types_replaces_previous_list() {
        let guard = RequestGuard::builder()
            .allowed_content_types(["text/plain"])
            .allowed_content_types(vec![
                String::from("application/json"),
                String::from("application/xml"),
            ])
            .build();

        assert_eq!(
            guard.config.allowed_content_types,
            Some(vec![
                "application/json".to_string(),
                "application/xml".to_string()
            ])
        );
    }

    /// For single-valued settings, the last setter call wins.
    #[test]
    fn scalar_setters_last_call_wins() {
        let guard = RequestGuard::builder()
            .max_body_size(1)
            .max_body_size(2)
            .timeout(Duration::from_secs(1))
            .timeout(Duration::from_millis(250))
            .on_violation(OnViolation::LogAndPass)
            .on_violation(OnViolation::Reject)
            .build();

        assert_eq!(guard.config.max_body_size, Some(2));
        assert_eq!(guard.config.timeout, Some(Duration::from_millis(250)));
        assert!(matches!(guard.on_violation, OnViolation::Reject));
    }

    /// Cloning a built guard preserves both the config and the policy.
    #[test]
    fn cloned_guard_keeps_configuration() {
        let guard = RequestGuard::builder()
            .max_body_size(512)
            .require_header("X-Request-Id")
            .on_violation(OnViolation::LogAndPass)
            .build();
        let copy = guard.clone();

        assert_eq!(copy.config.max_body_size, Some(512));
        assert_eq!(copy.config.required_headers, guard.config.required_headers);
        assert!(matches!(copy.on_violation, OnViolation::LogAndPass));
    }

    /// The `json`-only depth limit defaults to unset and is stored when given.
    #[cfg(feature = "json")]
    #[test]
    fn builder_max_json_depth() {
        assert_eq!(RequestGuard::builder().build().config.max_json_depth, None);

        let guard = RequestGuard::builder().max_json_depth(16).build();
        assert_eq!(guard.config.max_json_depth, Some(16));
    }
}