torch_web/
security.rs

1//! # Security Middleware and Utilities
2//!
3//! This module provides comprehensive security middleware and utilities for protecting
4//! Torch web applications. It includes protection against common web vulnerabilities,
5//! request signing, IP whitelisting, security headers, and input validation.
6//!
7//! ## Available Security Features
8//!
9//! - **Security Headers**: Automatic security headers (HSTS, CSP, X-Frame-Options, etc.)
10//! - **Request Signing**: HMAC-based request authentication
11//! - **IP Whitelisting**: Restrict access to specific IP addresses or ranges
12//! - **Request ID**: Generate unique IDs for request tracking
13//! - **Input Validation**: Validate and sanitize user input
//! - **Rate Limiting**: Protect against abuse and DoS attacks (not implemented in this module)
15//!
16//! ## Security Best Practices
17//!
18//! ### 1. Always Use HTTPS in Production
19//! ```text
20//! Configure your reverse proxy (nginx, Apache) or load balancer
21//! to terminate SSL and forward to your Torch application
22//! ```
23//!
24//! ### 2. Enable Security Headers
25//! ```rust
//! use torch_web::{App, Response, security::SecurityHeaders};
27//!
28//! let app = App::new()
29//!     .middleware(SecurityHeaders::new())
30//!     .get("/", |_req| async { Response::ok().body("Secure!") });
31//! ```
32//!
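//! ### 3. Validate All Input
//!
//! `InputValidator` is a coarse, pattern-based filter: it complements, but does not
//! replace, parameterized SQL queries and context-aware output escaping.
//!
//! ```rust
//! use torch_web::{App, Response, security::InputValidator};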
36//!
37//! let app = App::new()
38//!     .middleware(InputValidator)
39//!     .post("/api/data", |req| async move {
40//!         // Input is automatically validated
41//!         Response::ok().body("Data processed")
42//!     });
43//! ```
44//!
45//! ### 4. Use Request Signing for APIs
46//! ```rust
//! use torch_web::{App, Response, security::RequestSigning};
48//!
49//! let app = App::new()
50//!     .middleware(RequestSigning::new("your-secret-key"))
51//!     .post("/api/webhook", |req| async move {
52//!         // Request signature is automatically verified
53//!         Response::ok().body("Webhook processed")
54//!     });
55//! ```
56
57use std::collections::HashSet;
58use std::net::IpAddr;
59use std::str::FromStr;
60use crate::{Request, Response, middleware::Middleware};
61
62#[cfg(feature = "security")]
63use {
64    hmac::{Hmac, Mac},
65    sha2::Sha256,
66    base64::{Engine as _, engine::general_purpose},
67    uuid::Uuid,
68};
69
/// HMAC-based request signing middleware for API security.
///
/// This middleware verifies that incoming requests are signed with a shared secret,
/// providing authentication and integrity protection for API endpoints. It's particularly
/// useful for webhooks and server-to-server communication.
///
/// **Note**: Verification only happens when the `security` feature is enabled. Without
/// that feature the middleware forwards every request unchecked.
///
/// # How It Works
///
/// 1. The client computes an HMAC-SHA256, keyed with the shared secret, over the value of
///    the `X-Timestamp` header immediately followed by the request body. When
///    `X-Timestamp` is absent, the string `"0"` is used in its place.
/// 2. The MAC is sent in the `X-Signature` header, base64-encoded with the standard
///    alphabet and padding. There is no `sha256=` prefix, and the digest is not hex.
/// 3. The middleware recomputes the MAC and compares it with the provided one.
/// 4. A missing `X-Signature` header is rejected with `401 Unauthorized`
///    ("Missing signature"). A signature that does not match is rejected with
///    `401 Unauthorized` ("Invalid signature").
///
/// The timestamp is covered by the signature but is not checked for freshness. On its own,
/// this middleware therefore does not stop a captured request from being replayed.
///
/// # Examples
///
/// ## Basic Usage
///
/// ```rust
/// use torch_web::{App, Response, security::RequestSigning};
///
/// let app = App::new()
///     .middleware(RequestSigning::new("your-secret-key"))
///     .post("/webhook", |req| async move {
///         // Request signature has been verified
///         Response::ok().body("Webhook received")
///     });
/// ```
///
/// ## Client-Side Signing (Example)
///
/// ```python
/// import base64
/// import hashlib
/// import hmac
/// import time
///
/// import requests
///
/// def sign_request(secret: bytes, timestamp: str, body: bytes) -> str:
///     digest = hmac.new(secret, timestamp.encode("utf-8") + body, hashlib.sha256).digest()
///     return base64.b64encode(digest).decode("ascii")
///
/// # Send signed request
/// body = b'{"event": "user.created", "data": {}}'
/// timestamp = str(int(time.time()))
///
/// response = requests.post(
///     "https://your-api.com/webhook",
///     data=body,
///     headers={
///         "Content-Type": "application/json",
///         "X-Timestamp": timestamp,
///         "X-Signature": sign_request(b"your-secret-key", timestamp, body),
///     },
/// )
/// ```
///
/// ## Header Name
///
/// The header name `X-Signature` is fixed. Senders that use another header or format are
/// not compatible with this middleware. For example, GitHub sends `X-Hub-Signature-256`
/// with a `sha256=<hex>` value.
pub struct RequestSigning {
    /// HMAC key; the UTF-8 bytes of the secret passed to [`RequestSigning::new`].
    #[cfg(feature = "security")]
    secret: Vec<u8>,
    #[cfg(not(feature = "security"))]
    _phantom: std::marker::PhantomData<()>,
}

impl RequestSigning {
    /// Creates a verifier that uses the UTF-8 bytes of `secret` as the HMAC key.
    #[cfg(feature = "security")]
    pub fn new(secret: &str) -> Self {
        Self {
            secret: secret.as_bytes().to_vec(),
        }
    }

    /// Creates a pass-through instance: without the `security` feature the secret is
    /// ignored and no verification takes place.
    #[cfg(not(feature = "security"))]
    pub fn new(_secret: &str) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
164
165impl Middleware for RequestSigning {
166    fn call(
167        &self,
168        req: Request,
169        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
170    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
171        #[cfg(feature = "security")]
172        {
173            let secret = self.secret.clone();
174            Box::pin(async move {
175                // Verify request signature
176                if let Some(signature) = req.header("X-Signature") {
177                    let body = req.body();
178                    let timestamp = req.header("X-Timestamp").unwrap_or("0");
179                    
180                    let payload = format!("{}{}", timestamp, std::str::from_utf8(body).unwrap_or(""));
181                    
182                    let mut mac = Hmac::<Sha256>::new_from_slice(&secret)
183                        .expect("HMAC can take key of any size");
184                    mac.update(payload.as_bytes());
185                    let expected = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
186                    
187                    if signature != expected {
188                        return Response::with_status(http::StatusCode::UNAUTHORIZED)
189                            .body("Invalid signature");
190                    }
191                } else {
192                    return Response::with_status(http::StatusCode::UNAUTHORIZED)
193                        .body("Missing signature");
194                }
195                
196                next(req).await
197            })
198        }
199        
200        #[cfg(not(feature = "security"))]
201        {
202            Box::pin(async move {
203                next(req).await
204            })
205        }
206    }
207}
208
/// IP whitelist middleware.
///
/// A request is forwarded only when its client address meets one of two conditions:
/// - it equals an address added with [`allow_ip`](IpWhitelist::allow_ip);
/// - it lies inside a CIDR range added with [`allow_range`](IpWhitelist::allow_range).
///
/// An empty whitelist rejects every address.
///
/// The client address is read from the `X-Forwarded-For` / `X-Real-IP` request headers.
/// The sender controls those headers, so only rely on this middleware behind a reverse
/// proxy that sets them.
#[derive(Debug, Clone, Default)]
pub struct IpWhitelist {
    /// Addresses that are allowed exactly.
    allowed_ips: HashSet<IpAddr>,
    /// CIDR ranges as `(network address, prefix length)`. `allow_range` only stores
    /// prefixes that fit the address family (<= 32 for IPv4, <= 128 for IPv6).
    allowed_ranges: Vec<(IpAddr, u8)>,
}

impl IpWhitelist {
    /// Creates an empty whitelist.
    pub fn new() -> Self {
        Self::default()
    }

    /// Allows one exact address, e.g. `"192.168.1.1"` or `"::1"`.
    ///
    /// A string that is not a valid IP address is ignored.
    pub fn allow_ip(mut self, ip: &str) -> Self {
        if let Ok(ip) = IpAddr::from_str(ip) {
            self.allowed_ips.insert(ip);
        }
        self
    }

    /// Allows a CIDR range, e.g. `"10.0.0.0/8"` or `"2001:db8::/32"`.
    ///
    /// The range is ignored in any of these cases:
    /// - it has no `/`;
    /// - the address or the prefix does not parse;
    /// - the prefix is longer than the address (more than 32 for IPv4, 128 for IPv6).
    ///
    /// An over-long prefix used to be accepted and then panicked on the first lookup.
    pub fn allow_range(mut self, range: &str) -> Self {
        let entry = range.split_once('/').and_then(|(ip_str, prefix_str)| {
            let ip = IpAddr::from_str(ip_str).ok()?;
            let prefix = prefix_str.parse::<u8>().ok()?;
            let max_prefix = if ip.is_ipv4() { 32 } else { 128 };
            (prefix <= max_prefix).then(|| (ip, prefix))
        });
        if let Some(entry) = entry {
            self.allowed_ranges.push(entry);
        }
        self
    }

    /// Returns `true` if `client_ip` is allowed exactly or falls inside an allowed range.
    // Only the unit tests call this; the middleware uses `is_ip_allowed_static` with the
    // whitelist's fields directly.
    #[allow(dead_code)]
    fn is_ip_allowed(&self, client_ip: IpAddr) -> bool {
        is_ip_allowed_static(client_ip, &self.allowed_ips, &self.allowed_ranges)
    }
}

/// Returns `true` if `client_ip` is in `allowed_ips` or inside any `(network, prefix)`
/// range in `allowed_ranges`.
fn is_ip_allowed_static(
    client_ip: IpAddr,
    allowed_ips: &HashSet<IpAddr>,
    allowed_ranges: &[(IpAddr, u8)],
) -> bool {
    allowed_ips.contains(&client_ip)
        || allowed_ranges
            .iter()
            .any(|&(network, prefix)| ip_in_range_static(client_ip, network, prefix))
}

/// Returns `true` if `ip` shares its first `prefix` bits with `range_ip`.
///
/// Matching rules:
/// - addresses of different families never match;
/// - a prefix of 0 matches every address of the same family;
/// - a prefix longer than the address width matches nothing.
fn ip_in_range_static(ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
    match (ip, range_ip) {
        (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
            if prefix > 32 {
                return false;
            }
            // Prefix 0 means a shift by the full width, which overflows a plain `<<`.
            // `checked_shl` yields `None` there, i.e. an all-zero mask.
            let mask = u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0);
            (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
            if prefix > 128 {
                return false;
            }
            let mask = u128::MAX.checked_shl(128 - u32::from(prefix)).unwrap_or(0);
            (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
        }
        _ => false,
    }
}
313
314impl Middleware for IpWhitelist {
315    fn call(
316        &self,
317        req: Request,
318        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
319    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
320        let allowed_ips = self.allowed_ips.clone();
321        let allowed_ranges = self.allowed_ranges.clone();
322
323        Box::pin(async move {
324            // Extract client IP from headers or connection info
325            let client_ip = req.header("X-Forwarded-For")
326                .or_else(|| req.header("X-Real-IP"))
327                .and_then(|ip_str| IpAddr::from_str(ip_str).ok());
328
329            if let Some(client_ip) = client_ip {
330                if !is_ip_allowed_static(client_ip, &allowed_ips, &allowed_ranges) {
331                    return Response::with_status(http::StatusCode::FORBIDDEN)
332                        .body("IP address not allowed");
333                }
334            } else {
335                return Response::with_status(http::StatusCode::BAD_REQUEST)
336                    .body("Unable to determine client IP");
337            }
338
339            next(req).await
340        })
341    }
342}
343
344/// Request ID middleware for tracking requests
345pub struct RequestId;
346
347impl Middleware for RequestId {
348    fn call(
349        &self,
350        req: Request,
351        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
352    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
353        Box::pin(async move {
354            // Generate or extract request ID
355            let request_id = req.header("X-Request-ID")
356                .map(|id| id.to_string())
357                .unwrap_or_else(|| {
358                    #[cfg(feature = "security")]
359                    {
360                        Uuid::new_v4().to_string()
361                    }
362                    #[cfg(not(feature = "security"))]
363                    {
364                        format!("req_{}", std::time::SystemTime::now()
365                            .duration_since(std::time::UNIX_EPOCH)
366                            .unwrap_or_default()
367                            .as_millis())
368                    }
369                });
370
371            // Add request ID to request context (would need to extend Request struct)
372            // For now, we'll add it as a custom header in the response
373            
374            let mut response = next(req).await;
375            response = response.header("X-Request-ID", &request_id);
376            response
377        })
378    }
379}
380
381/// Input validation middleware
382pub struct InputValidator;
383
384impl InputValidator {
385    fn is_safe_input(input: &str) -> bool {
386        // Basic SQL injection patterns
387        let sql_patterns = [
388            "union", "select", "insert", "update", "delete", "drop", "create", "alter",
389            "exec", "execute", "sp_", "xp_", "--", "/*", "*/", ";",
390        ];
391
392        // Basic XSS patterns
393        let xss_patterns = [
394            "<script", "</script>", "javascript:", "onload=", "onerror=", "onclick=",
395            "onmouseover=", "onfocus=", "onblur=", "onchange=", "onsubmit=",
396        ];
397
398        let input_lower = input.to_lowercase();
399
400        // Check for SQL injection patterns
401        for pattern in &sql_patterns {
402            if input_lower.contains(pattern) {
403                return false;
404            }
405        }
406
407        // Check for XSS patterns
408        for pattern in &xss_patterns {
409            if input_lower.contains(pattern) {
410                return false;
411            }
412        }
413
414        // Check for path traversal
415        if input.contains("../") || input.contains("..\\") {
416            return false;
417        }
418
419        // Check for null bytes
420        if input.contains('\0') {
421            return false;
422        }
423
424        true
425    }
426
427    fn validate_request_data(req: &Request) -> Result<(), String> {
428        // Validate query parameters
429        for (key, value) in req.query_params() {
430            if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
431                return Err(format!("Invalid query parameter: {}", key));
432            }
433        }
434
435        // Validate path parameters
436        for (key, value) in req.params() {
437            if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
438                return Err(format!("Invalid path parameter: {}", key));
439            }
440        }
441
442        // Validate request body if it's text
443        if let Ok(body_str) = req.body_string() {
444            if !Self::is_safe_input(&body_str) {
445                return Err("Invalid request body content".to_string());
446            }
447        }
448
449        Ok(())
450    }
451}
452
453impl Middleware for InputValidator {
454    fn call(
455        &self,
456        req: Request,
457        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
458    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
459        Box::pin(async move {
460            // Validate input
461            if let Err(error) = Self::validate_request_data(&req) {
462                return Response::with_status(http::StatusCode::BAD_REQUEST)
463                    .body(format!("Input validation failed: {}", error));
464            }
465
466            next(req).await
467        })
468    }
469}
470
/// Enhanced security headers middleware.
///
/// Adds these headers to every response:
/// - `X-Content-Type-Options`
/// - `X-Frame-Options`
/// - `X-XSS-Protection`
/// - `Referrer-Policy`
/// - `Permissions-Policy`
/// - `Strict-Transport-Security`
/// - `Content-Security-Policy`, unless disabled with [`without_csp`](SecurityHeaders::without_csp)
#[derive(Debug, Clone)]
pub struct SecurityHeaders {
    /// Value of the `Content-Security-Policy` header; `None` means the header is not sent.
    content_security_policy: Option<String>,
}

/// Content-Security-Policy used by [`SecurityHeaders::new`].
const DEFAULT_CONTENT_SECURITY_POLICY: &str = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' https:; connect-src 'self'; frame-ancestors 'none';";

impl SecurityHeaders {
    /// Creates the middleware with a default Content-Security-Policy.
    ///
    /// The default policy limits most resources to the page's own origin. However, it
    /// allows inline scripts and styles (`'unsafe-inline'`), which also lets injected
    /// inline scripts run. Tighten it with [`with_csp`](Self::with_csp) where possible.
    pub fn new() -> Self {
        Self {
            content_security_policy: Some(DEFAULT_CONTENT_SECURITY_POLICY.to_string()),
        }
    }

    /// Replaces the Content-Security-Policy header value with `csp`.
    pub fn with_csp(mut self, csp: &str) -> Self {
        self.content_security_policy = Some(csp.to_string());
        self
    }

    /// Stops the middleware from sending a Content-Security-Policy header.
    pub fn without_csp(mut self) -> Self {
        self.content_security_policy = None;
        self
    }
}

impl Default for SecurityHeaders {
    /// Same as [`SecurityHeaders::new`].
    fn default() -> Self {
        Self::new()
    }
}
496
497impl Middleware for SecurityHeaders {
498    fn call(
499        &self,
500        req: Request,
501        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
502    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
503        let csp = self.content_security_policy.clone();
504        Box::pin(async move {
505            let mut response = next(req).await;
506
507            // Add comprehensive security headers
508            response = response
509                .header("X-Content-Type-Options", "nosniff")
510                .header("X-Frame-Options", "DENY")
511                .header("X-XSS-Protection", "1; mode=block")
512                .header("Referrer-Policy", "strict-origin-when-cross-origin")
513                .header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
514                .header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload");
515
516            if let Some(csp) = csp {
517                response = response.header("Content-Security-Policy", &csp);
518            }
519
520            response
521        })
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 `IpAddr` from its four octets.
    fn v4(octets: [u8; 4]) -> IpAddr {
        IpAddr::V4(Ipv4Addr::from(octets))
    }

    #[test]
    fn test_ip_whitelist() {
        let whitelist = IpWhitelist::new()
            .allow_ip("192.168.1.1")
            .allow_range("10.0.0.0/8");

        let permitted = [[192, 168, 1, 1], [10, 0, 0, 1], [10, 255, 255, 255]];
        for octets in permitted.iter() {
            assert!(whitelist.is_ip_allowed(v4(*octets)), "{:?} should be allowed", octets);
        }

        let refused = [[192, 168, 1, 2], [11, 0, 0, 1]];
        for octets in refused.iter() {
            assert!(!whitelist.is_ip_allowed(v4(*octets)), "{:?} should be refused", octets);
        }
    }

    #[test]
    fn test_input_validation() {
        let hostile = [
            "'; DROP TABLE users; --",
            "<script>alert('xss')</script>",
            "../../../etc/passwd",
            "test\0null",
        ];
        for sample in hostile.iter() {
            assert!(!InputValidator::is_safe_input(sample), "{:?} should be rejected", sample);
        }

        let benign = ["normal input text", "user@example.com"];
        for sample in benign.iter() {
            assert!(InputValidator::is_safe_input(sample), "{:?} should be accepted", sample);
        }
    }
}