tower_request_guard/
route.rs1use crate::guard::GuardConfig;
2use http::Request;
3use std::task::{Context, Poll};
4use std::time::Duration;
5use tower_layer::Layer;
6use tower_service::Service;
7
/// Per-route overrides applied on top of the global [`GuardConfig`](crate::guard::GuardConfig).
///
/// Build one with the chainable setters (or through [`route_guard`]) and resolve it against
/// the global configuration with [`RouteGuardConfig::merge_with`]. Unset (`None`) fields
/// inherit the corresponding global value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RouteGuardConfig {
    /// Route-specific body size limit; `None` inherits the global limit.
    pub(crate) max_body_size: Option<u64>,
    /// Route-specific timeout; `None` inherits the global timeout.
    pub(crate) timeout: Option<Duration>,
    /// Replacement content-type allow-list; `None` inherits the global list unchanged.
    pub(crate) allowed_content_types: Option<Vec<String>>,
    /// Globally-required headers this route does not require (ASCII case-insensitive match).
    pub(crate) skip_headers: Vec<String>,
    /// Extra headers this route requires on top of the non-skipped global ones.
    pub(crate) extra_required_headers: Vec<String>,
    /// When `true`, every guard check is disabled for this route.
    pub(crate) skip_all: bool,
    /// Route-specific maximum JSON depth; `None` inherits the global value.
    #[cfg(feature = "json")]
    pub(crate) max_json_depth: Option<u32>,
}
21
22impl RouteGuardConfig {
23 pub fn max_body_size(mut self, size: u64) -> Self {
25 self.max_body_size = Some(size);
26 self
27 }
28
29 pub fn timeout(mut self, duration: Duration) -> Self {
31 self.timeout = Some(duration);
32 self
33 }
34
35 pub fn allowed_content_types<I, S>(mut self, types: I) -> Self
37 where
38 I: IntoIterator<Item = S>,
39 S: Into<String>,
40 {
41 self.allowed_content_types = Some(types.into_iter().map(Into::into).collect());
42 self
43 }
44
45 pub fn skip_header(mut self, name: impl Into<String>) -> Self {
47 self.skip_headers.push(name.into());
48 self
49 }
50
51 pub fn require_header(mut self, name: impl Into<String>) -> Self {
53 self.extra_required_headers.push(name.into());
54 self
55 }
56
57 pub fn skip_all(mut self) -> Self {
59 self.skip_all = true;
60 self
61 }
62
63 #[cfg(feature = "json")]
65 pub fn max_json_depth(mut self, depth: u32) -> Self {
66 self.max_json_depth = Some(depth);
67 self
68 }
69
70 pub fn merge_with(&self, global: &GuardConfig) -> GuardConfig {
73 if self.skip_all {
74 return GuardConfig {
75 max_body_size: None,
76 timeout: None,
77 allowed_content_types: None,
78 required_headers: Vec::new(),
79 #[cfg(feature = "json")]
80 max_json_depth: None,
81 };
82 }
83
84 let mut required_headers = global.required_headers.clone();
86 required_headers.retain(|h| !self.skip_headers.iter().any(|s| s.eq_ignore_ascii_case(h)));
87 for extra in &self.extra_required_headers {
88 if !required_headers
89 .iter()
90 .any(|h| h.eq_ignore_ascii_case(extra))
91 {
92 required_headers.push(extra.clone());
93 }
94 }
95
96 GuardConfig {
97 max_body_size: self.max_body_size.or(global.max_body_size),
98 timeout: self.timeout.or(global.timeout),
99 allowed_content_types: self
100 .allowed_content_types
101 .clone()
102 .or_else(|| global.allowed_content_types.clone()),
103 required_headers,
104 #[cfg(feature = "json")]
105 max_json_depth: self.max_json_depth.or(global.max_json_depth),
106 }
107 }
108}
109
110pub fn route_guard<F>(f: F) -> RouteGuardLayer
113where
114 F: FnOnce(RouteGuardConfig) -> RouteGuardConfig,
115{
116 RouteGuardLayer(f(RouteGuardConfig::default()))
117}
118
119#[derive(Debug, Clone)]
121pub struct RouteGuardLayer(RouteGuardConfig);
122
123impl<S> Layer<S> for RouteGuardLayer {
124 type Service = RouteGuardInsertService<S>;
125
126 fn layer(&self, inner: S) -> Self::Service {
127 RouteGuardInsertService {
128 inner,
129 config: self.0.clone(),
130 }
131 }
132}
133
134#[derive(Debug, Clone)]
136pub struct RouteGuardInsertService<S> {
137 inner: S,
138 config: RouteGuardConfig,
139}
140
141impl<S, B> Service<Request<B>> for RouteGuardInsertService<S>
142where
143 S: Service<Request<B>>,
144{
145 type Response = S::Response;
146 type Error = S::Error;
147 type Future = S::Future;
148
149 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150 self.inner.poll_ready(cx)
151 }
152
153 fn call(&mut self, mut req: Request<B>) -> Self::Future {
154 req.extensions_mut().insert(self.config.clone());
155 self.inner.call(req)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guard::GuardConfig;
    use std::time::Duration;

    /// Global configuration with every field set, so both inheritance and overrides are observable.
    fn base_config() -> GuardConfig {
        GuardConfig {
            max_body_size: Some(1024),
            timeout: Some(Duration::from_secs(30)),
            allowed_content_types: Some(vec!["application/json".into()]),
            required_headers: vec!["Authorization".into(), "X-Request-Id".into()],
            #[cfg(feature = "json")]
            max_json_depth: Some(32),
        }
    }

    #[test]
    fn merge_overrides_numeric_values() {
        let route = RouteGuardConfig {
            max_body_size: Some(2048),
            timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, Some(2048));
        assert_eq!(merged.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn merge_replaces_content_types() {
        let route = RouteGuardConfig {
            allowed_content_types: Some(vec!["multipart/form-data".into()]),
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(
            merged.allowed_content_types,
            Some(vec!["multipart/form-data".into()])
        );
    }

    #[test]
    fn merge_skip_header_removes() {
        let route = RouteGuardConfig {
            skip_headers: vec!["Authorization".into()],
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.required_headers, vec!["X-Request-Id".to_string()]);
    }

    #[test]
    fn merge_skip_header_is_case_insensitive() {
        let route = RouteGuardConfig::default().skip_header("authorization");
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.required_headers, vec!["X-Request-Id".to_string()]);
    }

    #[test]
    fn merge_require_header_adds() {
        let route = RouteGuardConfig {
            extra_required_headers: vec!["X-Tenant-Id".into()],
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert!(merged.required_headers.contains(&"X-Tenant-Id".to_string()));
        assert!(merged
            .required_headers
            .contains(&"Authorization".to_string()));
    }

    #[test]
    fn merge_require_header_does_not_duplicate() {
        let route = RouteGuardConfig::default()
            .require_header("x-request-id")
            .require_header("X-Tenant-Id")
            .require_header("x-tenant-id");
        let merged = route.merge_with(&base_config());
        assert_eq!(
            merged.required_headers,
            vec![
                "Authorization".to_string(),
                "X-Request-Id".to_string(),
                "X-Tenant-Id".to_string(),
            ]
        );
    }

    #[test]
    fn merge_require_header_wins_over_skip() {
        let route = RouteGuardConfig::default()
            .skip_header("Authorization")
            .require_header("Authorization");
        let merged = route.merge_with(&base_config());
        assert_eq!(
            merged.required_headers,
            vec!["X-Request-Id".to_string(), "Authorization".to_string()]
        );
    }

    #[test]
    fn merge_skip_all_clears_everything() {
        let route = RouteGuardConfig {
            skip_all: true,
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, None);
        assert_eq!(merged.timeout, None);
        assert!(merged.allowed_content_types.is_none());
        assert!(merged.required_headers.is_empty());
    }

    #[test]
    fn merge_skip_all_ignores_route_overrides() {
        let route = RouteGuardConfig::default()
            .max_body_size(4096)
            .allowed_content_types(["text/plain"])
            .require_header("X-Tenant-Id")
            .skip_all();
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, None);
        assert!(merged.allowed_content_types.is_none());
        assert!(merged.required_headers.is_empty());
    }

    #[test]
    fn merge_inherits_unset_values() {
        let route = RouteGuardConfig::default();
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, Some(1024));
        assert_eq!(merged.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn merge_inherits_content_types_and_headers() {
        let merged = RouteGuardConfig::default().merge_with(&base_config());
        assert_eq!(
            merged.allowed_content_types,
            Some(vec!["application/json".to_string()])
        );
        assert_eq!(
            merged.required_headers,
            vec!["Authorization".to_string(), "X-Request-Id".to_string()]
        );
    }

    #[test]
    fn route_guard_applies_builder_to_default() {
        let layer = route_guard(|cfg| {
            cfg.timeout(Duration::from_secs(5))
                .skip_header("Authorization")
        });
        assert_eq!(layer.0.timeout, Some(Duration::from_secs(5)));
        assert_eq!(layer.0.skip_headers, vec!["Authorization".to_string()]);
        assert!(!layer.0.skip_all);
    }

    #[cfg(feature = "json")]
    #[test]
    fn merge_json_depth_overrides_and_inherits() {
        let overridden = RouteGuardConfig::default()
            .max_json_depth(4)
            .merge_with(&base_config());
        assert_eq!(overridden.max_json_depth, Some(4));
        let inherited = RouteGuardConfig::default().merge_with(&base_config());
        assert_eq!(inherited.max_json_depth, Some(32));
    }
}