1use std::collections::HashSet;
58use std::net::IpAddr;
59use std::str::FromStr;
60use crate::{Request, Response, middleware::Middleware};
61
62#[cfg(feature = "security")]
63use {
64 hmac::{Hmac, Mac},
65 sha2::Sha256,
66 base64::{Engine as _, engine::general_purpose},
67 uuid::Uuid,
68};
69
/// Middleware that authenticates requests with an HMAC-SHA256 signature.
///
/// With the `security` feature enabled, the shared secret is kept as raw bytes
/// and the [`Middleware`] implementation checks the `X-Signature` header
/// against it. Without the feature the type carries no data and the middleware
/// forwards every request unchecked.
pub struct RequestSigning {
    /// HMAC key: the UTF-8 bytes of the secret passed to [`RequestSigning::new`].
    #[cfg(feature = "security")]
    secret: Vec<u8>,
    /// Zero-sized placeholder used when the `security` feature is disabled.
    #[cfg(not(feature = "security"))]
    _phantom: std::marker::PhantomData<()>,
}

impl RequestSigning {
    /// Builds a signer keyed with the bytes of `secret`.
    #[cfg(feature = "security")]
    pub fn new(secret: &str) -> Self {
        let key_bytes = Vec::from(secret.as_bytes());
        Self { secret: key_bytes }
    }

    /// Builds a pass-through signer. `_secret` is ignored because the
    /// `security` feature is disabled.
    #[cfg(not(feature = "security"))]
    pub fn new(_secret: &str) -> Self {
        Self {
            _phantom: Default::default(),
        }
    }
}
164
165impl Middleware for RequestSigning {
166 fn call(
167 &self,
168 req: Request,
169 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
170 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
171 #[cfg(feature = "security")]
172 {
173 let secret = self.secret.clone();
174 Box::pin(async move {
175 if let Some(signature) = req.header("X-Signature") {
177 let body = req.body();
178 let timestamp = req.header("X-Timestamp").unwrap_or("0");
179
180 let payload = format!("{}{}", timestamp, std::str::from_utf8(body).unwrap_or(""));
181
182 let mut mac = Hmac::<Sha256>::new_from_slice(&secret)
183 .expect("HMAC can take key of any size");
184 mac.update(payload.as_bytes());
185 let expected = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
186
187 if signature != expected {
188 return Response::with_status(http::StatusCode::UNAUTHORIZED)
189 .body("Invalid signature");
190 }
191 } else {
192 return Response::with_status(http::StatusCode::UNAUTHORIZED)
193 .body("Missing signature");
194 }
195
196 next(req).await
197 })
198 }
199
200 #[cfg(not(feature = "security"))]
201 {
202 Box::pin(async move {
203 next(req).await
204 })
205 }
206 }
207}
208
/// Middleware that only admits clients whose IP address is listed explicitly
/// or falls inside one of the configured CIDR ranges.
pub struct IpWhitelist {
    /// Individual addresses that are always allowed.
    allowed_ips: HashSet<IpAddr>,
    /// CIDR ranges as `(network address, prefix length)` pairs. `allow_range`
    /// only stores prefixes that fit the address family (<= 32 / <= 128).
    allowed_ranges: Vec<(IpAddr, u8)>,
}

impl IpWhitelist {
    /// Creates an empty whitelist, which allows no address.
    pub fn new() -> Self {
        Self {
            allowed_ips: HashSet::new(),
            allowed_ranges: Vec::new(),
        }
    }

    /// Allows a single IPv4 or IPv6 address.
    ///
    /// Strings that do not parse as an IP address are silently ignored.
    pub fn allow_ip(mut self, ip: &str) -> Self {
        if let Ok(ip) = IpAddr::from_str(ip) {
            self.allowed_ips.insert(ip);
        }
        self
    }

    /// Allows a CIDR range such as `"10.0.0.0/8"` or `"2001:db8::/32"`.
    ///
    /// The input is silently ignored when it has no `/`, when either half fails
    /// to parse, or when the prefix is longer than the address family allows
    /// (32 bits for IPv4, 128 for IPv6).
    pub fn allow_range(mut self, range: &str) -> Self {
        if let Some((ip_str, prefix_str)) = range.split_once('/') {
            if let (Ok(ip), Ok(prefix)) = (IpAddr::from_str(ip_str), prefix_str.parse::<u8>()) {
                if prefix <= max_prefix_len(ip) {
                    self.allowed_ranges.push((ip, prefix));
                }
            }
        }
        self
    }

    /// Returns `true` if `client_ip` is whitelisted.
    // Only the unit tests call this; the middleware uses the free function directly.
    #[allow(dead_code)]
    fn is_ip_allowed(&self, client_ip: IpAddr) -> bool {
        is_ip_allowed_static(client_ip, &self.allowed_ips, &self.allowed_ranges)
    }

    /// Returns `true` if `ip` lies within `range_ip/prefix`.
    // Kept for compatibility; no caller in this file uses it any more.
    #[allow(dead_code)]
    fn ip_in_range(&self, ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
        ip_in_range_static(ip, range_ip, prefix)
    }
}

/// Bit width of `ip`'s address family, which is also the longest valid prefix.
fn max_prefix_len(ip: IpAddr) -> u8 {
    match ip {
        IpAddr::V4(_) => 32,
        IpAddr::V6(_) => 128,
    }
}

/// Returns `true` if `client_ip` is in `allowed_ips` or inside any of `allowed_ranges`.
fn is_ip_allowed_static(
    client_ip: IpAddr,
    allowed_ips: &HashSet<IpAddr>,
    allowed_ranges: &[(IpAddr, u8)],
) -> bool {
    allowed_ips.contains(&client_ip)
        || allowed_ranges
            .iter()
            .any(|&(range_ip, prefix)| ip_in_range_static(client_ip, range_ip, prefix))
}

/// Returns `true` if `ip` lies within the CIDR block `range_ip/prefix`.
///
/// Mixed address families never match. A prefix longer than the family's
/// width matches nothing instead of overflowing.
fn ip_in_range_static(ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
    if prefix > max_prefix_len(range_ip) {
        return false;
    }
    // The host-bit count can equal the full width (prefix 0). A plain shift by
    // the full width overflows, so use a checked shift of an all-ones word:
    // it yields an empty mask there, which makes "/0" match everything.
    match (ip, range_ip) {
        (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
            let mask = u32::MAX.checked_shl(u32::from(32 - prefix)).unwrap_or(0);
            (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
            let mask = u128::MAX.checked_shl(u32::from(128 - prefix)).unwrap_or(0);
            (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
        }
        _ => false,
    }
}
313
314impl Middleware for IpWhitelist {
315 fn call(
316 &self,
317 req: Request,
318 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
319 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
320 let allowed_ips = self.allowed_ips.clone();
321 let allowed_ranges = self.allowed_ranges.clone();
322
323 Box::pin(async move {
324 let client_ip = req.header("X-Forwarded-For")
326 .or_else(|| req.header("X-Real-IP"))
327 .and_then(|ip_str| IpAddr::from_str(ip_str).ok());
328
329 if let Some(client_ip) = client_ip {
330 if !is_ip_allowed_static(client_ip, &allowed_ips, &allowed_ranges) {
331 return Response::with_status(http::StatusCode::FORBIDDEN)
332 .body("IP address not allowed");
333 }
334 } else {
335 return Response::with_status(http::StatusCode::BAD_REQUEST)
336 .body("Unable to determine client IP");
337 }
338
339 next(req).await
340 })
341 }
342}
343
344pub struct RequestId;
346
347impl Middleware for RequestId {
348 fn call(
349 &self,
350 req: Request,
351 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
352 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
353 Box::pin(async move {
354 let request_id = req.header("X-Request-ID")
356 .map(|id| id.to_string())
357 .unwrap_or_else(|| {
358 #[cfg(feature = "security")]
359 {
360 Uuid::new_v4().to_string()
361 }
362 #[cfg(not(feature = "security"))]
363 {
364 format!("req_{}", std::time::SystemTime::now()
365 .duration_since(std::time::UNIX_EPOCH)
366 .unwrap_or_default()
367 .as_millis())
368 }
369 });
370
371 let mut response = next(req).await;
375 response = response.header("X-Request-ID", &request_id);
376 response
377 })
378 }
379}
380
381pub struct InputValidator;
383
384impl InputValidator {
385 fn is_safe_input(input: &str) -> bool {
386 let sql_patterns = [
388 "union", "select", "insert", "update", "delete", "drop", "create", "alter",
389 "exec", "execute", "sp_", "xp_", "--", "/*", "*/", ";",
390 ];
391
392 let xss_patterns = [
394 "<script", "</script>", "javascript:", "onload=", "onerror=", "onclick=",
395 "onmouseover=", "onfocus=", "onblur=", "onchange=", "onsubmit=",
396 ];
397
398 let input_lower = input.to_lowercase();
399
400 for pattern in &sql_patterns {
402 if input_lower.contains(pattern) {
403 return false;
404 }
405 }
406
407 for pattern in &xss_patterns {
409 if input_lower.contains(pattern) {
410 return false;
411 }
412 }
413
414 if input.contains("../") || input.contains("..\\") {
416 return false;
417 }
418
419 if input.contains('\0') {
421 return false;
422 }
423
424 true
425 }
426
427 fn validate_request_data(req: &Request) -> Result<(), String> {
428 for (key, value) in req.query_params() {
430 if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
431 return Err(format!("Invalid query parameter: {}", key));
432 }
433 }
434
435 for (key, value) in req.params() {
437 if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
438 return Err(format!("Invalid path parameter: {}", key));
439 }
440 }
441
442 if let Ok(body_str) = req.body_string() {
444 if !Self::is_safe_input(&body_str) {
445 return Err("Invalid request body content".to_string());
446 }
447 }
448
449 Ok(())
450 }
451}
452
453impl Middleware for InputValidator {
454 fn call(
455 &self,
456 req: Request,
457 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
458 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
459 Box::pin(async move {
460 if let Err(error) = Self::validate_request_data(&req) {
462 return Response::with_status(http::StatusCode::BAD_REQUEST)
463 .body(format!("Input validation failed: {}", error));
464 }
465
466 next(req).await
467 })
468 }
469}
470
/// Middleware that adds browser-hardening headers to every response.
pub struct SecurityHeaders {
    /// Value sent as `Content-Security-Policy`; `None` omits that header.
    content_security_policy: Option<String>,
}

impl SecurityHeaders {
    /// Creates the middleware with the built-in policy. The policy allows
    /// same-origin resources (inline scripts and styles permitted), images from
    /// `data:` or HTTPS, fonts from HTTPS, and forbids framing.
    pub fn new() -> Self {
        const DEFAULT_CSP: &str = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' https:; connect-src 'self'; frame-ancestors 'none';";
        let blank = Self {
            content_security_policy: None,
        };
        blank.with_csp(DEFAULT_CSP)
    }

    /// Replaces the policy; `csp` is sent verbatim as the header value.
    pub fn with_csp(mut self, csp: &str) -> Self {
        let policy = csp.to_owned();
        self.content_security_policy = Some(policy);
        self
    }

    /// Disables the `Content-Security-Policy` header entirely.
    pub fn without_csp(mut self) -> Self {
        self.content_security_policy = Option::None;
        self
    }
}
496
497impl Middleware for SecurityHeaders {
498 fn call(
499 &self,
500 req: Request,
501 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
502 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
503 let csp = self.content_security_policy.clone();
504 Box::pin(async move {
505 let mut response = next(req).await;
506
507 response = response
509 .header("X-Content-Type-Options", "nosniff")
510 .header("X-Frame-Options", "DENY")
511 .header("X-XSS-Protection", "1; mode=block")
512 .header("Referrer-Policy", "strict-origin-when-cross-origin")
513 .header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
514 .header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload");
515
516 if let Some(csp) = csp {
517 response = response.header("Content-Security-Policy", &csp);
518 }
519
520 response
521 })
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Exact-address and /8-range membership, including both range edges.
    #[test]
    fn test_ip_whitelist() {
        let whitelist = IpWhitelist::new()
            .allow_ip("192.168.1.1")
            .allow_range("10.0.0.0/8");

        let cases: [([u8; 4], bool); 5] = [
            ([192, 168, 1, 1], true),    // the single allowed address
            ([10, 0, 0, 1], true),       // inside 10.0.0.0/8
            ([10, 255, 255, 255], true), // last address of the /8
            ([192, 168, 1, 2], false),   // neighbour of the single address
            ([11, 0, 0, 1], false),      // just past the /8
        ];
        for &(octets, expected) in &cases {
            let candidate = IpAddr::V4(Ipv4Addr::from(octets));
            assert_eq!(whitelist.is_ip_allowed(candidate), expected, "{}", candidate);
        }
    }

    /// Known attack strings are rejected; ordinary text is accepted.
    #[test]
    fn test_input_validation() {
        let hostile = [
            "'; DROP TABLE users; --",
            "<script>alert('xss')</script>",
            "../../../etc/passwd",
            "test\0null",
        ];
        for payload in &hostile {
            assert!(!InputValidator::is_safe_input(payload), "accepted {:?}", payload);
        }

        let benign = ["normal input text", "user@example.com"];
        for text in &benign {
            assert!(InputValidator::is_safe_input(text), "rejected {:?}", text);
        }
    }
}