1use std::collections::HashSet;
4use std::net::IpAddr;
5use std::str::FromStr;
6use crate::{Request, Response, middleware::Middleware};
7
8#[cfg(feature = "security")]
9use {
10 hmac::{Hmac, Mac},
11 sha2::Sha256,
12 base64::{Engine as _, engine::general_purpose},
13 uuid::Uuid,
14};
15
/// Middleware that authenticates requests with an HMAC-SHA256 signature.
///
/// With the `security` feature enabled the shared secret is stored; without it the type
/// carries no data at all.
pub struct RequestSigning {
    /// Raw bytes of the shared HMAC secret.
    #[cfg(feature = "security")]
    secret: Vec<u8>,
    /// Zero-sized placeholder used when the `security` feature is disabled.
    #[cfg(not(feature = "security"))]
    _phantom: std::marker::PhantomData<()>,
}

impl RequestSigning {
    /// Creates the middleware from a shared secret.
    ///
    /// Without the `security` feature the secret is accepted but discarded.
    pub fn new(secret: &str) -> Self {
        // Only the `security` build has somewhere to keep the secret.
        #[cfg(not(feature = "security"))]
        let _ = secret;

        Self {
            #[cfg(feature = "security")]
            secret: secret.as_bytes().to_owned(),
            #[cfg(not(feature = "security"))]
            _phantom: std::marker::PhantomData,
        }
    }
}
39
40impl Middleware for RequestSigning {
41 fn call(
42 &self,
43 req: Request,
44 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
45 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
46 #[cfg(feature = "security")]
47 {
48 let secret = self.secret.clone();
49 Box::pin(async move {
50 if let Some(signature) = req.header("X-Signature") {
52 let body = req.body();
53 let timestamp = req.header("X-Timestamp").unwrap_or("0");
54
55 let payload = format!("{}{}", timestamp, std::str::from_utf8(body).unwrap_or(""));
56
57 let mut mac = Hmac::<Sha256>::new_from_slice(&secret)
58 .expect("HMAC can take key of any size");
59 mac.update(payload.as_bytes());
60 let expected = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
61
62 if signature != expected {
63 return Response::with_status(http::StatusCode::UNAUTHORIZED)
64 .body("Invalid signature");
65 }
66 } else {
67 return Response::with_status(http::StatusCode::UNAUTHORIZED)
68 .body("Missing signature");
69 }
70
71 next(req).await
72 })
73 }
74
75 #[cfg(not(feature = "security"))]
76 {
77 Box::pin(async move {
78 next(req).await
79 })
80 }
81 }
82}
83
/// Middleware that only admits requests whose client IP is explicitly allowed.
///
/// Addresses are allowed individually with [`IpWhitelist::allow_ip`] or as CIDR ranges with
/// [`IpWhitelist::allow_range`]. An empty whitelist allows nothing.
pub struct IpWhitelist {
    /// Individually allowed addresses.
    allowed_ips: HashSet<IpAddr>,
    /// Allowed CIDR networks as `(network address, prefix length)`. `allow_range` only stores
    /// prefix lengths that fit the address family (<= 32 for IPv4, <= 128 for IPv6).
    allowed_ranges: Vec<(IpAddr, u8)>,
}

impl IpWhitelist {
    /// Creates an empty whitelist, which allows no address.
    pub fn new() -> Self {
        Self {
            allowed_ips: HashSet::new(),
            allowed_ranges: Vec::new(),
        }
    }

    /// Allows a single address such as `192.168.1.1` or `::1`.
    ///
    /// Input that does not parse as an IP address is ignored, so it allows nothing.
    pub fn allow_ip(mut self, ip: &str) -> Self {
        if let Ok(ip) = IpAddr::from_str(ip) {
            self.allowed_ips.insert(ip);
        }
        self
    }

    /// Allows a CIDR range such as `10.0.0.0/8` or `2001:db8::/32`.
    ///
    /// Some input is ignored, so it allows nothing:
    /// - malformed input;
    /// - a prefix length wider than the address (over 32 for IPv4, over 128 for IPv6).
    ///
    /// Without this check, an oversized prefix would make the mask computation underflow.
    pub fn allow_range(mut self, range: &str) -> Self {
        if let Some((ip_str, prefix_str)) = range.split_once('/') {
            if let (Ok(ip), Ok(prefix)) = (IpAddr::from_str(ip_str), prefix_str.parse::<u8>()) {
                let max_prefix = if ip.is_ipv4() { 32 } else { 128 };
                if prefix <= max_prefix {
                    self.allowed_ranges.push((ip, prefix));
                }
            }
        }
        self
    }

    /// Returns `true` if `client_ip` is allowed exactly or falls inside an allowed range.
    // Only exercised by unit tests; the middleware uses `is_ip_allowed_static`.
    #[allow(dead_code)]
    fn is_ip_allowed(&self, client_ip: IpAddr) -> bool {
        self.allowed_ips.contains(&client_ip)
            || self
                .allowed_ranges
                .iter()
                .any(|&(range_ip, prefix)| self.ip_in_range(client_ip, range_ip, prefix))
    }

    /// Returns `true` when `ip` lies inside the CIDR network `range_ip/prefix`.
    ///
    /// Families must match; an IPv4 address never matches an IPv6 range. A prefix of `0`
    /// matches every address of the family. A prefix longer than the address never matches.
    // Only exercised through `is_ip_allowed`, which is itself test-only.
    #[allow(dead_code)]
    fn ip_in_range(&self, ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
        match (ip, range_ip) {
            (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
                if prefix > 32 {
                    return false;
                }
                // A shift by the full width (prefix 0) yields `None` -> empty mask, instead of
                // overflowing as `1u32 << 32` does.
                let mask = u32::MAX.checked_shl(u32::from(32 - prefix)).unwrap_or(0);
                (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
            }
            (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
                if prefix > 128 {
                    return false;
                }
                let mask = u128::MAX.checked_shl(u32::from(128 - prefix)).unwrap_or(0);
                (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
            }
            _ => false,
        }
    }
}
150
151fn is_ip_allowed_static(
152 client_ip: IpAddr,
153 allowed_ips: &HashSet<IpAddr>,
154 allowed_ranges: &[(IpAddr, u8)]
155) -> bool {
156 if allowed_ips.contains(&client_ip) {
158 return true;
159 }
160
161 for (range_ip, prefix) in allowed_ranges {
163 if ip_in_range_static(client_ip, *range_ip, *prefix) {
164 return true;
165 }
166 }
167
168 false
169}
170
/// Returns `true` when `ip` lies inside the CIDR network `range_ip/prefix`.
///
/// Address families must match: an IPv4 address never matches an IPv6 range and vice versa.
/// A prefix of `0` matches every address of the same family. A prefix longer than the address
/// width (32 for IPv4, 128 for IPv6) is invalid and never matches.
fn ip_in_range_static(ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
    match (ip, range_ip) {
        (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
            if prefix > 32 {
                return false;
            }
            // `checked_shl` returns `None` for a full-width shift (prefix 0). That maps to an
            // all-zero mask instead of overflowing as `1u32 << 32` would.
            let mask = u32::MAX.checked_shl(u32::from(32 - prefix)).unwrap_or(0);
            (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
            if prefix > 128 {
                return false;
            }
            let mask = u128::MAX.checked_shl(u32::from(128 - prefix)).unwrap_or(0);
            (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
        }
        _ => false,
    }
}
188
189impl Middleware for IpWhitelist {
190 fn call(
191 &self,
192 req: Request,
193 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
194 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
195 let allowed_ips = self.allowed_ips.clone();
196 let allowed_ranges = self.allowed_ranges.clone();
197
198 Box::pin(async move {
199 let client_ip = req.header("X-Forwarded-For")
201 .or_else(|| req.header("X-Real-IP"))
202 .and_then(|ip_str| IpAddr::from_str(ip_str).ok());
203
204 if let Some(client_ip) = client_ip {
205 if !is_ip_allowed_static(client_ip, &allowed_ips, &allowed_ranges) {
206 return Response::with_status(http::StatusCode::FORBIDDEN)
207 .body("IP address not allowed");
208 }
209 } else {
210 return Response::with_status(http::StatusCode::BAD_REQUEST)
211 .body("Unable to determine client IP");
212 }
213
214 next(req).await
215 })
216 }
217}
218
219pub struct RequestId;
221
222impl Middleware for RequestId {
223 fn call(
224 &self,
225 req: Request,
226 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
227 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
228 Box::pin(async move {
229 let request_id = req.header("X-Request-ID")
231 .map(|id| id.to_string())
232 .unwrap_or_else(|| {
233 #[cfg(feature = "security")]
234 {
235 Uuid::new_v4().to_string()
236 }
237 #[cfg(not(feature = "security"))]
238 {
239 format!("req_{}", std::time::SystemTime::now()
240 .duration_since(std::time::UNIX_EPOCH)
241 .unwrap_or_default()
242 .as_millis())
243 }
244 });
245
246 let mut response = next(req).await;
250 response = response.header("X-Request-ID", &request_id);
251 response
252 })
253 }
254}
255
256pub struct InputValidator;
258
259impl InputValidator {
260 fn is_safe_input(input: &str) -> bool {
261 let sql_patterns = [
263 "union", "select", "insert", "update", "delete", "drop", "create", "alter",
264 "exec", "execute", "sp_", "xp_", "--", "/*", "*/", ";",
265 ];
266
267 let xss_patterns = [
269 "<script", "</script>", "javascript:", "onload=", "onerror=", "onclick=",
270 "onmouseover=", "onfocus=", "onblur=", "onchange=", "onsubmit=",
271 ];
272
273 let input_lower = input.to_lowercase();
274
275 for pattern in &sql_patterns {
277 if input_lower.contains(pattern) {
278 return false;
279 }
280 }
281
282 for pattern in &xss_patterns {
284 if input_lower.contains(pattern) {
285 return false;
286 }
287 }
288
289 if input.contains("../") || input.contains("..\\") {
291 return false;
292 }
293
294 if input.contains('\0') {
296 return false;
297 }
298
299 true
300 }
301
302 fn validate_request_data(req: &Request) -> Result<(), String> {
303 for (key, value) in req.query_params() {
305 if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
306 return Err(format!("Invalid query parameter: {}", key));
307 }
308 }
309
310 for (key, value) in req.params() {
312 if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
313 return Err(format!("Invalid path parameter: {}", key));
314 }
315 }
316
317 if let Ok(body_str) = req.body_string() {
319 if !Self::is_safe_input(&body_str) {
320 return Err("Invalid request body content".to_string());
321 }
322 }
323
324 Ok(())
325 }
326}
327
328impl Middleware for InputValidator {
329 fn call(
330 &self,
331 req: Request,
332 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
333 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
334 Box::pin(async move {
335 if let Err(error) = Self::validate_request_data(&req) {
337 return Response::with_status(http::StatusCode::BAD_REQUEST)
338 .body(format!("Input validation failed: {}", error));
339 }
340
341 next(req).await
342 })
343 }
344}
345
/// Middleware that adds a fixed set of security-related response headers, plus an optional
/// `Content-Security-Policy` header.
pub struct SecurityHeaders {
    /// Policy sent as `Content-Security-Policy`; `None` omits that header.
    content_security_policy: Option<String>,
}

impl SecurityHeaders {
    /// Creates the middleware with a built-in default Content-Security-Policy.
    pub fn new() -> Self {
        const DEFAULT_POLICY: &str = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' https:; connect-src 'self'; frame-ancestors 'none';";
        Self {
            content_security_policy: Some(String::from(DEFAULT_POLICY)),
        }
    }

    /// Replaces the Content-Security-Policy with `csp`.
    pub fn with_csp(self, csp: &str) -> Self {
        Self {
            content_security_policy: Some(String::from(csp)),
        }
    }

    /// Disables the `Content-Security-Policy` header entirely.
    pub fn without_csp(self) -> Self {
        Self {
            content_security_policy: None,
        }
    }
}
371
372impl Middleware for SecurityHeaders {
373 fn call(
374 &self,
375 req: Request,
376 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
377 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
378 let csp = self.content_security_policy.clone();
379 Box::pin(async move {
380 let mut response = next(req).await;
381
382 response = response
384 .header("X-Content-Type-Options", "nosniff")
385 .header("X-Frame-Options", "DENY")
386 .header("X-XSS-Protection", "1; mode=block")
387 .header("Referrer-Policy", "strict-origin-when-cross-origin")
388 .header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
389 .header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload");
390
391 if let Some(csp) = csp {
392 response = response.header("Content-Security-Policy", &csp);
393 }
394
395 response
396 })
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Shorthand for building an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn test_ip_whitelist() {
        let whitelist = IpWhitelist::new()
            .allow_ip("192.168.1.1")
            .allow_range("10.0.0.0/8");

        // The exact address plus both ends of the /8 range.
        for ip in [v4(192, 168, 1, 1), v4(10, 0, 0, 1), v4(10, 255, 255, 255)] {
            assert!(whitelist.is_ip_allowed(ip), "{} should be allowed", ip);
        }

        // A neighbour of the exact address and an address just outside the range.
        for ip in [v4(192, 168, 1, 2), v4(11, 0, 0, 1)] {
            assert!(!whitelist.is_ip_allowed(ip), "{} should be rejected", ip);
        }
    }

    #[test]
    fn test_input_validation() {
        let hostile = [
            "'; DROP TABLE users; --",
            "<script>alert('xss')</script>",
            "../../../etc/passwd",
            "test\0null",
        ];
        for input in hostile {
            assert!(!InputValidator::is_safe_input(input), "{:?} should be rejected", input);
        }

        let benign = ["normal input text", "user@example.com"];
        for input in benign {
            assert!(InputValidator::is_safe_input(input), "{:?} should be accepted", input);
        }
    }
}