torch_web/
security.rs

//! Security middleware and utilities for the Torch framework.
2
3use std::collections::HashSet;
4use std::net::IpAddr;
5use std::str::FromStr;
6use crate::{Request, Response, middleware::Middleware};
7
8#[cfg(feature = "security")]
9use {
10    hmac::{Hmac, Mac},
11    sha2::Sha256,
12    base64::{Engine as _, engine::general_purpose},
13    uuid::Uuid,
14};
15
/// Request-signing middleware for API security.
///
/// With the `security` feature enabled the shared secret is kept as raw
/// bytes. Without that feature the type stores nothing, and the `Middleware`
/// implementation in this file forwards every request without checking it.
pub struct RequestSigning {
    #[cfg(feature = "security")]
    secret: Vec<u8>,
    #[cfg(not(feature = "security"))]
    _phantom: std::marker::PhantomData<()>,
}

impl RequestSigning {
    /// Builds the middleware, copying the UTF-8 bytes of `secret` for use as
    /// the signing key.
    #[cfg(feature = "security")]
    pub fn new(secret: &str) -> Self {
        let secret = Vec::from(secret.as_bytes());
        Self { secret }
    }

    /// Builds a placeholder middleware; `_secret` is ignored because the
    /// `security` feature is disabled.
    #[cfg(not(feature = "security"))]
    pub fn new(_secret: &str) -> Self {
        let _phantom = std::marker::PhantomData;
        Self { _phantom }
    }
}
39
40impl Middleware for RequestSigning {
41    fn call(
42        &self,
43        req: Request,
44        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
45    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
46        #[cfg(feature = "security")]
47        {
48            let secret = self.secret.clone();
49            Box::pin(async move {
50                // Verify request signature
51                if let Some(signature) = req.header("X-Signature") {
52                    let body = req.body();
53                    let timestamp = req.header("X-Timestamp").unwrap_or("0");
54                    
55                    let payload = format!("{}{}", timestamp, std::str::from_utf8(body).unwrap_or(""));
56                    
57                    let mut mac = Hmac::<Sha256>::new_from_slice(&secret)
58                        .expect("HMAC can take key of any size");
59                    mac.update(payload.as_bytes());
60                    let expected = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
61                    
62                    if signature != expected {
63                        return Response::with_status(http::StatusCode::UNAUTHORIZED)
64                            .body("Invalid signature");
65                    }
66                } else {
67                    return Response::with_status(http::StatusCode::UNAUTHORIZED)
68                        .body("Missing signature");
69                }
70                
71                next(req).await
72            })
73        }
74        
75        #[cfg(not(feature = "security"))]
76        {
77            Box::pin(async move {
78                next(req).await
79            })
80        }
81    }
82}
83
/// IP allow-list middleware.
///
/// Holds a set of individually allowed addresses plus a list of CIDR ranges.
/// Entries are added with the builder methods [`IpWhitelist::allow_ip`] and
/// [`IpWhitelist::allow_range`].
#[derive(Debug, Clone, Default)]
pub struct IpWhitelist {
    allowed_ips: HashSet<IpAddr>,
    allowed_ranges: Vec<(IpAddr, u8)>, // network address and prefix length
}

impl IpWhitelist {
    /// Creates an empty allow-list (no address is permitted yet).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a single address. Input that does not parse as an IPv4 or IPv6
    /// address is silently ignored.
    pub fn allow_ip(mut self, ip: &str) -> Self {
        if let Ok(ip) = IpAddr::from_str(ip) {
            self.allowed_ips.insert(ip);
        }
        self
    }

    /// Adds a CIDR range written as `address/prefix` (e.g. `10.0.0.0/8`).
    ///
    /// Input is silently ignored when it has no `/`, when either half fails
    /// to parse, or when the prefix is longer than the address family allows
    /// (32 for IPv4, 128 for IPv6).
    pub fn allow_range(mut self, range: &str) -> Self {
        if let Some((ip_str, prefix_str)) = range.split_once('/') {
            if let (Ok(ip), Ok(prefix)) = (IpAddr::from_str(ip_str), prefix_str.parse::<u8>()) {
                let max_prefix = match ip {
                    IpAddr::V4(_) => 32,
                    IpAddr::V6(_) => 128,
                };
                if prefix <= max_prefix {
                    self.allowed_ranges.push((ip, prefix));
                }
            }
        }
        self
    }

    /// Returns `true` when `client_ip` is listed explicitly or falls inside
    /// one of the configured ranges.
    // Only exercised by the unit tests; the middleware uses the free-function
    // twin because its future cannot borrow `self`.
    #[allow(dead_code)]
    fn is_ip_allowed(&self, client_ip: IpAddr) -> bool {
        self.allowed_ips.contains(&client_ip)
            || self
                .allowed_ranges
                .iter()
                .any(|&(range_ip, prefix)| self.ip_in_range(client_ip, range_ip, prefix))
    }

    /// Returns `true` when `ip` lies in the network `range_ip/prefix`.
    ///
    /// Addresses of different families never match, and a prefix longer than
    /// the address width matches nothing.
    #[allow(dead_code)]
    fn ip_in_range(&self, ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
        match (ip, range_ip) {
            (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
                if prefix > 32 {
                    return false;
                }
                // `checked_shl` yields `None` for a shift of the full width
                // (prefix 0), which corresponds to the all-zero "match
                // everything" mask.
                let mask = u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0);
                (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
            }
            (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
                if prefix > 128 {
                    return false;
                }
                let mask = u128::MAX.checked_shl(128 - u32::from(prefix)).unwrap_or(0);
                (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
            }
            _ => false,
        }
    }
}
150
151fn is_ip_allowed_static(
152    client_ip: IpAddr,
153    allowed_ips: &HashSet<IpAddr>,
154    allowed_ranges: &[(IpAddr, u8)]
155) -> bool {
156    // Check exact IP matches
157    if allowed_ips.contains(&client_ip) {
158        return true;
159    }
160
161    // Check IP ranges
162    for (range_ip, prefix) in allowed_ranges {
163        if ip_in_range_static(client_ip, *range_ip, *prefix) {
164            return true;
165        }
166    }
167
168    false
169}
170
/// Returns `true` when `ip` lies inside the CIDR network `range_ip/prefix`.
///
/// * `ip` – address under test.
/// * `range_ip` – any address inside the network (host bits are masked off).
/// * `prefix` – number of leading network bits; `0` matches every address of
///   the same family.
///
/// Addresses of different families never match. A prefix longer than the
/// address width (32 for IPv4, 128 for IPv6) matches nothing instead of
/// overflowing the shift.
fn ip_in_range_static(ip: IpAddr, range_ip: IpAddr, prefix: u8) -> bool {
    match (ip, range_ip) {
        (IpAddr::V4(ip), IpAddr::V4(range_ip)) => {
            if prefix > 32 {
                return false;
            }
            // A shift by the full bit width is reported as `None`; that is the
            // prefix-0 case, whose mask is all zeros.
            let mask = u32::MAX.checked_shl(32 - u32::from(prefix)).unwrap_or(0);
            (u32::from(ip) & mask) == (u32::from(range_ip) & mask)
        }
        (IpAddr::V6(ip), IpAddr::V6(range_ip)) => {
            if prefix > 128 {
                return false;
            }
            let mask = u128::MAX.checked_shl(128 - u32::from(prefix)).unwrap_or(0);
            (u128::from(ip) & mask) == (u128::from(range_ip) & mask)
        }
        _ => false,
    }
}
188
189impl Middleware for IpWhitelist {
190    fn call(
191        &self,
192        req: Request,
193        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
194    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
195        let allowed_ips = self.allowed_ips.clone();
196        let allowed_ranges = self.allowed_ranges.clone();
197
198        Box::pin(async move {
199            // Extract client IP from headers or connection info
200            let client_ip = req.header("X-Forwarded-For")
201                .or_else(|| req.header("X-Real-IP"))
202                .and_then(|ip_str| IpAddr::from_str(ip_str).ok());
203
204            if let Some(client_ip) = client_ip {
205                if !is_ip_allowed_static(client_ip, &allowed_ips, &allowed_ranges) {
206                    return Response::with_status(http::StatusCode::FORBIDDEN)
207                        .body("IP address not allowed");
208                }
209            } else {
210                return Response::with_status(http::StatusCode::BAD_REQUEST)
211                    .body("Unable to determine client IP");
212            }
213
214            next(req).await
215        })
216    }
217}
218
/// Request ID middleware for tracking requests.
///
/// A stateless marker type: its `Middleware` implementation echoes the
/// incoming `X-Request-ID` header on the response, or generates an ID when
/// the request carries none.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestId;
221
222impl Middleware for RequestId {
223    fn call(
224        &self,
225        req: Request,
226        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
227    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
228        Box::pin(async move {
229            // Generate or extract request ID
230            let request_id = req.header("X-Request-ID")
231                .map(|id| id.to_string())
232                .unwrap_or_else(|| {
233                    #[cfg(feature = "security")]
234                    {
235                        Uuid::new_v4().to_string()
236                    }
237                    #[cfg(not(feature = "security"))]
238                    {
239                        format!("req_{}", std::time::SystemTime::now()
240                            .duration_since(std::time::UNIX_EPOCH)
241                            .unwrap_or_default()
242                            .as_millis())
243                    }
244                });
245
246            // Add request ID to request context (would need to extend Request struct)
247            // For now, we'll add it as a custom header in the response
248            
249            let mut response = next(req).await;
250            response = response.header("X-Request-ID", &request_id);
251            response
252        })
253    }
254}
255
/// Input validation middleware.
///
/// A stateless marker type: its `Middleware` implementation rejects requests
/// whose query parameters, path parameters or text body contain patterns
/// from a fixed deny-list.
#[derive(Debug, Clone, Copy, Default)]
pub struct InputValidator;
258
259impl InputValidator {
260    fn is_safe_input(input: &str) -> bool {
261        // Basic SQL injection patterns
262        let sql_patterns = [
263            "union", "select", "insert", "update", "delete", "drop", "create", "alter",
264            "exec", "execute", "sp_", "xp_", "--", "/*", "*/", ";",
265        ];
266
267        // Basic XSS patterns
268        let xss_patterns = [
269            "<script", "</script>", "javascript:", "onload=", "onerror=", "onclick=",
270            "onmouseover=", "onfocus=", "onblur=", "onchange=", "onsubmit=",
271        ];
272
273        let input_lower = input.to_lowercase();
274
275        // Check for SQL injection patterns
276        for pattern in &sql_patterns {
277            if input_lower.contains(pattern) {
278                return false;
279            }
280        }
281
282        // Check for XSS patterns
283        for pattern in &xss_patterns {
284            if input_lower.contains(pattern) {
285                return false;
286            }
287        }
288
289        // Check for path traversal
290        if input.contains("../") || input.contains("..\\") {
291            return false;
292        }
293
294        // Check for null bytes
295        if input.contains('\0') {
296            return false;
297        }
298
299        true
300    }
301
302    fn validate_request_data(req: &Request) -> Result<(), String> {
303        // Validate query parameters
304        for (key, value) in req.query_params() {
305            if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
306                return Err(format!("Invalid query parameter: {}", key));
307            }
308        }
309
310        // Validate path parameters
311        for (key, value) in req.params() {
312            if !Self::is_safe_input(key) || !Self::is_safe_input(value) {
313                return Err(format!("Invalid path parameter: {}", key));
314            }
315        }
316
317        // Validate request body if it's text
318        if let Ok(body_str) = req.body_string() {
319            if !Self::is_safe_input(&body_str) {
320                return Err("Invalid request body content".to_string());
321            }
322        }
323
324        Ok(())
325    }
326}
327
328impl Middleware for InputValidator {
329    fn call(
330        &self,
331        req: Request,
332        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
333    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
334        Box::pin(async move {
335            // Validate input
336            if let Err(error) = Self::validate_request_data(&req) {
337                return Response::with_status(http::StatusCode::BAD_REQUEST)
338                    .body(format!("Input validation failed: {}", error));
339            }
340
341            next(req).await
342        })
343    }
344}
345
/// Enhanced security headers middleware.
///
/// Stores the `Content-Security-Policy` value to send, or `None` to omit that
/// header. The remaining headers added by the `Middleware` implementation
/// are fixed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeaders {
    content_security_policy: Option<String>,
}

impl SecurityHeaders {
    /// Policy installed by [`SecurityHeaders::new`].
    const DEFAULT_CSP: &'static str = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' https:; connect-src 'self'; frame-ancestors 'none';";

    /// Creates the middleware with the built-in default policy.
    pub fn new() -> Self {
        Self {
            content_security_policy: Some(Self::DEFAULT_CSP.to_string()),
        }
    }

    /// Replaces the policy with `csp`.
    pub fn with_csp(mut self, csp: &str) -> Self {
        self.content_security_policy = Some(csp.to_string());
        self
    }

    /// Disables the `Content-Security-Policy` header entirely.
    pub fn without_csp(mut self) -> Self {
        self.content_security_policy = None;
        self
    }
}

impl Default for SecurityHeaders {
    /// Same configuration as [`SecurityHeaders::new`].
    fn default() -> Self {
        Self::new()
    }
}
371
372impl Middleware for SecurityHeaders {
373    fn call(
374        &self,
375        req: Request,
376        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
377    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
378        let csp = self.content_security_policy.clone();
379        Box::pin(async move {
380            let mut response = next(req).await;
381
382            // Add comprehensive security headers
383            response = response
384                .header("X-Content-Type-Options", "nosniff")
385                .header("X-Frame-Options", "DENY")
386                .header("X-XSS-Protection", "1; mode=block")
387                .header("Referrer-Policy", "strict-origin-when-cross-origin")
388                .header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
389                .header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload");
390
391            if let Some(csp) = csp {
392                response = response.header("Content-Security-Policy", &csp);
393            }
394
395            response
396        })
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// One exact address plus one /8 range: listed and in-range addresses
    /// pass, neighbours just outside do not.
    #[test]
    fn test_ip_whitelist() {
        let whitelist = IpWhitelist::new()
            .allow_ip("192.168.1.1")
            .allow_range("10.0.0.0/8");

        let cases: [([u8; 4], bool); 5] = [
            ([192, 168, 1, 1], true),
            ([10, 0, 0, 1], true),
            ([10, 255, 255, 255], true),
            ([192, 168, 1, 2], false),
            ([11, 0, 0, 1], false),
        ];

        for &(octets, expected) in cases.iter() {
            let addr = IpAddr::V4(Ipv4Addr::from(octets));
            assert_eq!(
                whitelist.is_ip_allowed(addr),
                expected,
                "unexpected verdict for {}",
                addr
            );
        }
    }

    /// Representative malicious payloads are rejected; ordinary text passes.
    #[test]
    fn test_input_validation() {
        let rejected = [
            "'; DROP TABLE users; --",
            "<script>alert('xss')</script>",
            "../../../etc/passwd",
            "test\0null",
        ];
        for sample in rejected.iter() {
            assert!(
                !InputValidator::is_safe_input(sample),
                "expected rejection of {:?}",
                sample
            );
        }

        let accepted = ["normal input text", "user@example.com"];
        for sample in accepted.iter() {
            assert!(
                InputValidator::is_safe_input(sample),
                "expected acceptance of {:?}",
                sample
            );
        }
    }
}