tower_request_guard/
guard.rs1use crate::violation::OnViolation;
2use std::time::Duration;
3
/// Validation settings held by a [`RequestGuard`].
///
/// Produced by [`RequestGuardBuilder::build`]. The fields are `pub(crate)`, so
/// code outside this crate cannot construct or read them directly. A `None` or
/// empty value means the corresponding builder method was never called.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardConfig {
    /// Body-size limit from `RequestGuardBuilder::max_body_size`.
    /// NOTE(review): presumably measured in bytes; the enforcing code is outside this file.
    pub(crate) max_body_size: Option<u64>,
    /// Time limit from `RequestGuardBuilder::timeout`; `None` if never set.
    pub(crate) timeout: Option<Duration>,
    /// Content-type allow-list; `None` if never set.
    /// `Some(vec![])` is representable and is distinct from `None`.
    pub(crate) allowed_content_types: Option<Vec<String>>,
    /// Header names registered via `RequestGuardBuilder::require_header`, in call order.
    pub(crate) required_headers: Vec<String>,
    /// JSON depth limit; this field only exists when the `json` feature is enabled.
    #[cfg(feature = "json")]
    pub(crate) max_json_depth: Option<u32>,
}
14
15#[derive(Debug, Clone)]
17pub struct RequestGuard {
18 pub(crate) config: GuardConfig,
19 pub(crate) on_violation: OnViolation,
20}
21
22impl RequestGuard {
23 pub fn builder() -> RequestGuardBuilder {
25 RequestGuardBuilder::default()
26 }
27
28 pub fn layer(self) -> crate::layer::RequestGuardLayer {
30 crate::layer::RequestGuardLayer::new(self)
31 }
32}
33
34#[derive(Default)]
36pub struct RequestGuardBuilder {
37 max_body_size: Option<u64>,
38 timeout: Option<Duration>,
39 allowed_content_types: Option<Vec<String>>,
40 required_headers: Vec<String>,
41 on_violation: OnViolation,
42 #[cfg(feature = "json")]
43 max_json_depth: Option<u32>,
44}
45
46impl RequestGuardBuilder {
47 pub fn max_body_size(mut self, size: u64) -> Self {
49 self.max_body_size = Some(size);
50 self
51 }
52
53 pub fn timeout(mut self, duration: Duration) -> Self {
55 self.timeout = Some(duration);
56 self
57 }
58
59 pub fn allowed_content_types<I, S>(mut self, types: I) -> Self
61 where
62 I: IntoIterator<Item = S>,
63 S: Into<String>,
64 {
65 self.allowed_content_types = Some(types.into_iter().map(Into::into).collect());
66 self
67 }
68
69 pub fn require_header(mut self, name: impl Into<String>) -> Self {
71 self.required_headers.push(name.into());
72 self
73 }
74
75 pub fn on_violation(mut self, policy: OnViolation) -> Self {
77 self.on_violation = policy;
78 self
79 }
80
81 #[cfg(feature = "json")]
83 pub fn max_json_depth(mut self, depth: u32) -> Self {
84 self.max_json_depth = Some(depth);
85 self
86 }
87
88 pub fn build(self) -> RequestGuard {
90 RequestGuard {
91 config: GuardConfig {
92 max_body_size: self.max_body_size,
93 timeout: self.timeout,
94 allowed_content_types: self.allowed_content_types,
95 required_headers: self.required_headers,
96 #[cfg(feature = "json")]
97 max_json_depth: self.max_json_depth,
98 },
99 on_violation: self.on_violation,
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// An untouched builder leaves every option unset and rejects on violation.
    #[test]
    fn builder_defaults() {
        let guard = RequestGuard::builder().build();
        assert_eq!(guard.config.max_body_size, None);
        assert_eq!(guard.config.timeout, None);
        assert!(guard.config.allowed_content_types.is_none());
        assert!(guard.config.required_headers.is_empty());
        assert!(matches!(guard.on_violation, OnViolation::Reject));
    }

    /// Every setter's value reaches the built config unchanged.
    #[test]
    fn builder_sets_all_fields() {
        let guard = RequestGuard::builder()
            .max_body_size(1_048_576)
            .timeout(Duration::from_secs(30))
            .allowed_content_types(["application/json"])
            .require_header("Authorization")
            .require_header("X-Request-Id")
            .on_violation(OnViolation::LogAndPass)
            .build();

        assert_eq!(guard.config.max_body_size, Some(1_048_576));
        assert_eq!(guard.config.timeout, Some(Duration::from_secs(30)));
        assert_eq!(
            guard.config.allowed_content_types,
            Some(vec!["application/json".to_string()])
        );
        assert_eq!(
            guard.config.required_headers,
            vec!["Authorization".to_string(), "X-Request-Id".to_string()]
        );
        assert!(matches!(guard.on_violation, OnViolation::LogAndPass));
    }

    /// A partial chain sets only what it touches; everything else keeps its default.
    #[test]
    fn builder_chaining() {
        let guard = RequestGuard::builder()
            .max_body_size(1024)
            .timeout(Duration::from_secs(5))
            .require_header("Auth")
            .build();

        assert_eq!(guard.config.max_body_size, Some(1024));
        assert_eq!(guard.config.timeout, Some(Duration::from_secs(5)));
        assert_eq!(guard.config.required_headers, vec!["Auth".to_string()]);
        assert!(guard.config.allowed_content_types.is_none());
        assert!(matches!(guard.on_violation, OnViolation::Reject));
    }

    /// A second `allowed_content_types` call replaces the first list instead of appending.
    #[test]
    fn allowed_content_types_replaces_previous_list() {
        let guard = RequestGuard::builder()
            .allowed_content_types(["text/plain"])
            .allowed_content_types(vec![String::from("application/json")])
            .build();

        assert_eq!(
            guard.config.allowed_content_types,
            Some(vec!["application/json".to_string()])
        );
    }

    /// The JSON depth limit defaults to unset and is forwarded when configured.
    #[cfg(feature = "json")]
    #[test]
    fn builder_json_depth() {
        assert_eq!(RequestGuard::builder().build().config.max_json_depth, None);
        let guard = RequestGuard::builder().max_json_depth(32).build();
        assert_eq!(guard.config.max_json_depth, Some(32));
    }
}