// wae_authentication/csrf/service.rs
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::Utc;
use std::{collections::HashMap, iter, sync::Arc};
use tokio::sync::RwLock;

use crate::csrf::{CsrfConfig, CsrfError, CsrfResult};

/// Server-side bookkeeping for a single issued CSRF token.
///
/// The plaintext token is never kept; only its hash is stored.
#[derive(Debug, Clone)]
struct CsrfTokenRecord {
    /// Hash of the plaintext token (as produced by `CsrfService::hash_token`).
    token_hash: String,
    /// Expiry as Unix seconds (compared against `Utc::now().timestamp()`).
    expires_at: i64,
    /// Session the token was issued to.
    session_id: String,
}

/// A freshly issued CSRF token returned to the caller.
#[derive(Debug, Clone)]
pub struct CsrfToken {
    /// Plaintext token value to hand to the client.
    pub token: String,
    /// Expiry as Unix seconds (same clock as `Utc::now().timestamp()`).
    pub expires_at: i64,
}

30#[derive(Debug, Clone)]
32pub struct CsrfService {
33 config: CsrfConfig,
34 tokens: Arc<RwLock<HashMap<String, CsrfTokenRecord>>>,
35}
36
37impl CsrfService {
38 pub fn new(config: CsrfConfig) -> Self {
43 Self { config, tokens: Arc::new(RwLock::new(HashMap::new())) }
44 }
45
46 pub fn default() -> Self {
48 Self::new(CsrfConfig::default())
49 }
50
51 pub async fn generate_token(&self, session_id: &str) -> CsrfResult<CsrfToken> {
56 let token_bytes = uuid::Uuid::new_v4().into_bytes();
57 let token = URL_SAFE_NO_PAD.encode(&token_bytes[..self.config.token_length.min(16)]);
58 let token_hash = Self::hash_token(&token);
59 let expires_at = Utc::now().timestamp() + self.config.token_ttl as i64;
60
61 let record = CsrfTokenRecord { token_hash, expires_at, session_id: session_id.to_string() };
62
63 self.tokens.write().await.insert(session_id.to_string(), record);
64
65 self.cleanup_expired_tokens().await;
66
67 Ok(CsrfToken { token, expires_at })
68 }
69
70 pub async fn validate_token(&self, session_id: &str, token: &str) -> CsrfResult<bool> {
76 let tokens = self.tokens.read().await;
77
78 let record = tokens.get(session_id).ok_or(CsrfError::InvalidToken)?;
79
80 if record.expires_at < Utc::now().timestamp() {
81 return Err(CsrfError::TokenExpired);
82 }
83
84 let token_hash = Self::hash_token(token);
85
86 if record.token_hash != token_hash {
87 return Err(CsrfError::InvalidToken);
88 }
89
90 Ok(true)
91 }
92
93 pub async fn revoke_token(&self, session_id: &str) {
98 self.tokens.write().await.remove(session_id);
99 }
100
101 fn hash_token(token: &str) -> String {
103 use sha2::{Digest, Sha256};
104 let mut hasher = Sha256::new();
105 hasher.update(token.as_bytes());
106 let result = hasher.finalize();
107 URL_SAFE_NO_PAD.encode(result)
108 }
109
110 async fn cleanup_expired_tokens(&self) {
112 let now = Utc::now().timestamp();
113 let mut tokens = self.tokens.write().await;
114 tokens.retain(|_, record| record.expires_at > now);
115 }
116
117 pub fn config(&self) -> &CsrfConfig {
119 &self.config
120 }
121}
122
123impl Default for CsrfService {
124 fn default() -> Self {
125 Self::new(CsrfConfig::default())
126 }
127}