1use crate::core::{Middleware, Next};
4use crate::types::{Request, Response};
5use async_trait::async_trait;
6use sha1::{Digest, Sha1};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::{Duration, SystemTime};
10use uuid::Uuid;
11
/// CSRF protection middleware.
///
/// Tokens have the form `"<unix-seconds>:<hex digest>"`, where the digest binds the
/// session id and issue timestamp to `secret_key`.
pub struct CsrfMiddleware {
    /// Server-side secret mixed into every token signature. Never printed by `Debug`.
    secret_key: String,
    /// Form field consulted for the token when the request header is absent.
    token_name: String,
    /// NOTE(review): settable via the builder but not read anywhere in this file.
    cookie_name: String,
    /// Request header inspected for the token.
    header_name: String,
    /// Path prefixes that bypass CSRF handling entirely.
    exclude_paths: Vec<String>,
    /// Session id -> (issued token, issue time). Written when tokens are issued and pruned
    /// by age; token validation itself is stateless and does not consult this map.
    token_store: Arc<RwLock<HashMap<String, (String, SystemTime)>>>,
    /// Maximum accepted token age.
    token_lifetime: Duration,
}

// Hand-written so the signing secret and the stored per-session tokens never end up in
// logs or panic messages via `{:?}`.
impl std::fmt::Debug for CsrfMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CsrfMiddleware")
            .field("secret_key", &"<redacted>")
            .field("token_name", &self.token_name)
            .field("cookie_name", &self.cookie_name)
            .field("header_name", &self.header_name)
            .field("exclude_paths", &self.exclude_paths)
            .field("token_lifetime", &self.token_lifetime)
            .finish_non_exhaustive()
    }
}
23
24impl CsrfMiddleware {
25 pub fn new(secret_key: String) -> Self {
27 Self {
28 secret_key,
29 token_name: "csrf_token".to_string(),
30 cookie_name: "csrf_token".to_string(),
31 header_name: "X-CSRF-Token".to_string(),
32 exclude_paths: vec![],
33 token_store: Arc::new(RwLock::new(HashMap::new())),
34 token_lifetime: Duration::from_secs(3600), }
36 }
37
38 pub fn token_name(mut self, name: String) -> Self {
40 self.token_name = name;
41 self
42 }
43
44 pub fn cookie_name(mut self, name: String) -> Self {
46 self.cookie_name = name;
47 self
48 }
49
50 pub fn header_name(mut self, name: String) -> Self {
52 self.header_name = name;
53 self
54 }
55
56 pub fn exclude_path(mut self, path: String) -> Self {
58 self.exclude_paths.push(path);
59 self
60 }
61
62 pub fn token_lifetime(mut self, lifetime: Duration) -> Self {
64 self.token_lifetime = lifetime;
65 self
66 }
67
68 fn generate_token(&self, session_id: &str) -> String {
70 let timestamp = SystemTime::now()
71 .duration_since(SystemTime::UNIX_EPOCH)
72 .unwrap()
73 .as_secs();
74
75 let raw_token = format!("{}:{}:{}", session_id, timestamp, self.secret_key);
76 let mut hasher = Sha1::new();
77 hasher.update(raw_token.as_bytes());
78 let hash = hasher.finalize();
79
80 format!("{}:{}", timestamp, hex::encode(hash))
81 }
82
83 fn validate_token(&self, token: &str, session_id: &str) -> bool {
85 let parts: Vec<&str> = token.split(':').collect();
86 if parts.len() != 2 {
87 return false;
88 }
89
90 let timestamp_str = parts[0];
91 let hash_str = parts[1];
92
93 if let Ok(timestamp) = timestamp_str.parse::<u64>() {
94 let token_time = SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp);
95 let now = SystemTime::now();
96
97 if now.duration_since(token_time).unwrap_or(Duration::MAX) > self.token_lifetime {
99 return false;
100 }
101
102 let raw_token = format!("{}:{}:{}", session_id, timestamp, self.secret_key);
104 let mut hasher = Sha1::new();
105 hasher.update(raw_token.as_bytes());
106 let expected_hash = hex::encode(hasher.finalize());
107
108 return hash_str == expected_hash;
109 }
110
111 false
112 }
113
114 fn cleanup_expired_tokens(&self) {
116 let mut store = self.token_store.write().unwrap();
117 let now = SystemTime::now();
118 store.retain(|_, (_, created_at)| {
119 now.duration_since(*created_at).unwrap_or(Duration::MAX) <= self.token_lifetime
120 });
121 }
122}
123
124#[async_trait]
125impl Middleware for CsrfMiddleware {
126 async fn call(&self, mut request: Request, next: Next) -> crate::Result<Response> {
127 let path = request.uri.path();
128
129 if self.exclude_paths.iter().any(|p| path.starts_with(p)) {
131 return next.run(request).await;
132 }
133
134 self.cleanup_expired_tokens();
136
137 if matches!(
139 request.method,
140 crate::types::HttpMethod::GET
141 | crate::types::HttpMethod::HEAD
142 | crate::types::HttpMethod::OPTIONS
143 ) {
144 let session_id = request
146 .cookie("session_id")
147 .map(|c| c.value.clone())
148 .unwrap_or_else(|| Uuid::new_v4().to_string());
149
150 let token = self.generate_token(&session_id);
151
152 {
154 let mut store = self.token_store.write().unwrap();
155 store.insert(session_id.clone(), (token.clone(), SystemTime::now()));
156 }
157
158 request
160 .extensions
161 .insert("csrf_token".to_string(), token.clone());
162
163 let mut response = next.run(request).await?;
164
165 response.headers.insert("X-CSRF-Token".to_string(), token);
167
168 return Ok(response);
169 }
170
171 let session_id = request
173 .cookie("session_id")
174 .map(|c| c.value.clone())
175 .unwrap_or_default();
176
177 if session_id.is_empty() {
178 return Ok(
179 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Missing session")
180 );
181 }
182
183 let token = request.headers.get(&self.header_name).cloned().or_else(|| {
185 request.form(&self.token_name).map(|s| s.to_string())
187 });
188
189 let token = match token {
190 Some(t) => t,
191 None => {
192 return Ok(
193 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Missing token")
194 );
195 }
196 };
197
198 if !self.validate_token(&token, &session_id) {
200 return Ok(
201 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Invalid token")
202 );
203 }
204
205 next.run(request).await
206 }
207}
208
/// Middleware that stamps every response with legacy XSS / sniffing / framing headers.
///
/// NOTE(review): current browsers no longer implement the XSS auditor that
/// `X-XSS-Protection` controls, and MDN/OWASP now recommend sending `0`. Consider making
/// `filtering(false)` the default.
#[derive(Debug)]
pub struct XssProtectionMiddleware {
    // `true` sends `X-XSS-Protection: 1[...]`, `false` sends `0`.
    enable_filtering: bool,
    // When filtering is on, append `; mode=block` to the header value.
    block_mode: bool,
}

impl XssProtectionMiddleware {
    /// Builds the middleware with filtering and block mode both switched on.
    pub fn new() -> Self {
        XssProtectionMiddleware {
            block_mode: true,
            enable_filtering: true,
        }
    }

    /// Turns the `X-XSS-Protection` filter on or off.
    pub fn filtering(self, enable: bool) -> Self {
        XssProtectionMiddleware {
            enable_filtering: enable,
            ..self
        }
    }

    /// Chooses whether `mode=block` accompanies an enabled filter.
    pub fn block_mode(self, block: bool) -> Self {
        XssProtectionMiddleware {
            block_mode: block,
            ..self
        }
    }
}

impl Default for XssProtectionMiddleware {
    /// Same as [`XssProtectionMiddleware::new`].
    fn default() -> Self {
        XssProtectionMiddleware::new()
    }
}
243
244#[async_trait]
245impl Middleware for XssProtectionMiddleware {
246 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
247 let mut response = next.run(request).await?;
248
249 if self.enable_filtering {
251 let header_value = if self.block_mode {
252 "1; mode=block"
253 } else {
254 "1"
255 };
256 response
257 .headers
258 .insert("X-XSS-Protection".to_string(), header_value.to_string());
259 } else {
260 response
261 .headers
262 .insert("X-XSS-Protection".to_string(), "0".to_string());
263 }
264
265 response
267 .headers
268 .insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
269
270 response
272 .headers
273 .insert("X-Frame-Options".to_string(), "DENY".to_string());
274
275 Ok(response)
276 }
277}
278
/// Middleware that attaches a Content Security Policy header to every response.
///
/// Directives live in a map keyed by directive name, so setting the same directive twice
/// replaces its source list.
#[derive(Debug)]
pub struct CspMiddleware {
    /// Directive name (e.g. `script-src`) mapped to its source expressions.
    directives: HashMap<String, Vec<String>>,
    /// When `true`, the policy is sent as `Content-Security-Policy-Report-Only`.
    report_only: bool,
}

impl CspMiddleware {
    /// Creates an empty policy (no directives) in enforcing mode.
    pub fn new() -> Self {
        Self {
            directives: HashMap::new(),
            report_only: false,
        }
    }

    /// The built-in policy used by [`Default`]: `'self'` for default/font/connect sources,
    /// `'self'` plus `data:` for images, inline scripts and styles allowed, and framing
    /// forbidden via `frame-ancestors 'none'`.
    ///
    /// NOTE(review): `'unsafe-inline'` in `script-src` permits inline scripts, which
    /// removes much of CSP's XSS protection. Consider nonces or hashes instead.
    pub fn default_policy() -> Self {
        let mut csp = Self::new();
        csp.directive("default-src", vec!["'self'".to_string()]);
        csp.directive(
            "script-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );
        csp.directive(
            "style-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );
        csp.directive("img-src", vec!["'self'".to_string(), "data:".to_string()]);
        csp.directive("font-src", vec!["'self'".to_string()]);
        csp.directive("connect-src", vec!["'self'".to_string()]);
        csp.directive("frame-ancestors", vec!["'none'".to_string()]);
        csp
    }

    /// Sets (or replaces) directive `name` with `values` and returns `&mut Self` for chaining.
    ///
    /// An empty `values` renders the bare directive name (e.g. `upgrade-insecure-requests`).
    pub fn directive(&mut self, name: &str, values: Vec<String>) -> &mut Self {
        self.directives.insert(name.to_string(), values);
        self
    }

    /// Selects report-only mode (`true`) or enforcing mode (`false`).
    pub fn report_only(mut self, report_only: bool) -> Self {
        self.report_only = report_only;
        self
    }

    /// Renders the policy as `"name v1 v2; name2 v3"`.
    ///
    /// Directives are emitted sorted by name, so the header is identical across runs and
    /// processes (`HashMap` iteration order is randomised). Directives without values are
    /// emitted as the bare name, with no trailing space. An empty policy renders as `""`.
    fn build_header_value(&self) -> String {
        let mut entries: Vec<(&String, &Vec<String>)> = self.directives.iter().collect();
        // Keys are unique, so an unstable sort is fully deterministic here.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        entries
            .into_iter()
            .map(|(name, values)| {
                if values.is_empty() {
                    name.clone()
                } else {
                    format!("{} {}", name, values.join(" "))
                }
            })
            .collect::<Vec<_>>()
            .join("; ")
    }
}

impl Default for CspMiddleware {
    /// Equivalent to [`CspMiddleware::default_policy`].
    fn default() -> Self {
        Self::default_policy()
    }
}
341
342#[async_trait]
343impl Middleware for CspMiddleware {
344 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
345 let mut response = next.run(request).await?;
346
347 let header_name = if self.report_only {
348 "Content-Security-Policy-Report-Only"
349 } else {
350 "Content-Security-Policy"
351 };
352
353 let header_value = self.build_header_value();
354 response
355 .headers
356 .insert(header_name.to_string(), header_value);
357
358 Ok(response)
359 }
360}
361
/// Dependency-free helpers for escaping and loosely validating untrusted input.
pub mod sanitize {
    /// Escapes `input` for interpolation into HTML text or quoted attribute values.
    ///
    /// `&`, `<`, `>`, `"`, `'` and `/` become `&amp;`, `&lt;`, `&gt;`, `&quot;`, `&#x27;`
    /// and `&#x2F;`. Every other character is copied unchanged. The input is scanned once,
    /// so the `&` of an emitted entity is never escaped a second time. This is not
    /// sufficient for URL, JavaScript or CSS contexts.
    pub fn html(input: &str) -> String {
        let mut escaped = String::with_capacity(input.len());
        for ch in input.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#x27;"),
                '/' => escaped.push_str("&#x2F;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Doubles `'`, `"` and `\`, and strips NUL characters, for embedding in a SQL string
    /// literal.
    ///
    /// NOTE(review): correct escaping is dialect-dependent (e.g. whether `\` is an escape
    /// character), so this is not a substitute for parameterized queries. Prefer bind
    /// parameters wherever the database driver supports them.
    pub fn sql(input: &str) -> String {
        input
            .replace('\'', "''")
            .replace('"', "\"\"")
            .replace('\\', "\\\\")
            .replace('\0', "")
    }

    /// Reduces `input` to a single path component.
    ///
    /// Keeps only Unicode alphanumerics, `.`, `_` and `-`, so path separators disappear.
    /// If what remains consists solely of dots (`.`, `..`, ...), which would still name a
    /// directory when joined onto a path, an empty string is returned instead.
    pub fn filename(input: &str) -> String {
        let cleaned: String = input
            .chars()
            .filter(|c: &char| c.is_alphanumeric() || *c == '.' || *c == '_' || *c == '-')
            .collect();

        if cleaned.chars().all(|c| c == '.') {
            String::new()
        } else {
            cleaned
        }
    }

    /// Loose plausibility check: contains `@` and is 4..=254 bytes long.
    ///
    /// Does not validate structure. For example, `@example.com` is accepted.
    pub fn is_valid_email(email: &str) -> bool {
        email.contains('@') && email.len() > 3 && email.len() < 255
    }

    /// Returns `true` if `url` starts with the lowercase scheme prefix `http://` or
    /// `https://`. The check is case-sensitive and does not validate the rest of the URL.
    pub fn is_valid_url(url: &str) -> bool {
        url.starts_with("http://") || url.starts_with("https://")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_html_sanitization() {
        let input = "<script>alert('xss')</script>";
        let expected = "&lt;script&gt;alert(&#x27;xss&#x27;)&lt;&#x2F;script&gt;";
        assert_eq!(sanitize::html(input), expected);
    }

    #[test]
    fn test_html_sanitization_escapes_ampersand_once() {
        assert_eq!(sanitize::html("a & \"b\""), "a &amp; &quot;b&quot;");
        assert_eq!(sanitize::html("&lt;"), "&amp;lt;");
        assert_eq!(sanitize::html(""), "");
    }

    #[test]
    fn test_filename_sanitization() {
        let input = "../../etc/passwd";
        let expected = "....etcpasswd";
        assert_eq!(sanitize::filename(input), expected);
    }

    #[test]
    fn test_filename_rejects_dot_only_names() {
        assert_eq!(sanitize::filename(".."), "");
        assert_eq!(sanitize::filename("./."), "");
        assert_eq!(sanitize::filename(""), "");
        assert_eq!(
            sanitize::filename("report-v1.2_final.pdf"),
            "report-v1.2_final.pdf"
        );
    }

    #[test]
    fn test_email_validation() {
        assert!(sanitize::is_valid_email("test@example.com"));
        assert!(!sanitize::is_valid_email("invalid"));
        assert!(sanitize::is_valid_email("@example.com"));
    }

    #[test]
    fn test_url_validation() {
        assert!(sanitize::is_valid_url("https://example.com"));
        assert!(sanitize::is_valid_url("http://example.com"));
        assert!(!sanitize::is_valid_url("ftp://example.com"));
        assert!(!sanitize::is_valid_url("example.com"));
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = CsrfMiddleware::new("secret_key".to_string());
        let token = middleware.generate_token("session_123");
        assert!(!token.is_empty());
        assert!(token.contains(':'));
    }

    #[test]
    fn test_csp_header_building() {
        let mut csp = CspMiddleware::new();
        csp.directive("default-src", vec!["'self'".to_string()]);
        csp.directive(
            "script-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );

        let header = csp.build_header_value();
        assert!(header.contains("default-src 'self'"));
        assert!(header.contains("script-src 'self' 'unsafe-inline'"));
    }
}