Skip to main content

wae_authentication/csrf/
service.rs

//! CSRF service implementation.
2
3use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
4use chrono::Utc;
5use std::{collections::HashMap, sync::Arc};
6use tokio::sync::RwLock;
7
8use crate::csrf::{CsrfConfig, CsrfError, CsrfResult};
9
/// Stored CSRF token record (one per session).
#[derive(Debug, Clone)]
struct CsrfTokenRecord {
    /// Token value, hashed (the plaintext token is not stored).
    token_hash: String,
    /// Expiry as a Unix timestamp (seconds).
    expires_at: i64,
    /// Session ID the token was issued for.
    ///
    /// Not read by the service, because the token map is already keyed by session ID. It is
    /// kept so that a record is self-describing, e.g. in `Debug` output. The `allow` silences
    /// rustc's `dead_code` warning, which deliberately ignores derived `Debug`/`Clone` impls.
    #[allow(dead_code)]
    session_id: String,
}
20
/// CSRF token handed back to callers of the service.
#[derive(Clone, Debug)]
pub struct CsrfToken {
    /// Plaintext token value, meant to be transported to the client.
    pub token: String,
    /// Expiry as a Unix timestamp (seconds).
    pub expires_at: i64,
}
29
30/// CSRF 服务
31#[derive(Debug, Clone)]
32pub struct CsrfService {
33    config: CsrfConfig,
34    tokens: Arc<RwLock<HashMap<String, CsrfTokenRecord>>>,
35}
36
37impl CsrfService {
38    /// 创建新的 CSRF 服务
39    ///
40    /// # Arguments
41    /// * `config` - CSRF 配置
42    pub fn new(config: CsrfConfig) -> Self {
43        Self { config, tokens: Arc::new(RwLock::new(HashMap::new())) }
44    }
45
46    /// 使用默认配置创建 CSRF 服务
47    pub fn default() -> Self {
48        Self::new(CsrfConfig::default())
49    }
50
51    /// 生成 CSRF 令牌
52    ///
53    /// # Arguments
54    /// * `session_id` - 会话 ID
55    pub async fn generate_token(&self, session_id: &str) -> CsrfResult<CsrfToken> {
56        let token_bytes = uuid::Uuid::new_v4().into_bytes();
57        let token = URL_SAFE_NO_PAD.encode(&token_bytes[..self.config.token_length.min(16)]);
58        let token_hash = Self::hash_token(&token);
59        let expires_at = Utc::now().timestamp() + self.config.token_ttl as i64;
60
61        let record = CsrfTokenRecord { token_hash, expires_at, session_id: session_id.to_string() };
62
63        self.tokens.write().await.insert(session_id.to_string(), record);
64
65        self.cleanup_expired_tokens().await;
66
67        Ok(CsrfToken { token, expires_at })
68    }
69
70    /// 验证 CSRF 令牌
71    ///
72    /// # Arguments
73    /// * `session_id` - 会话 ID
74    /// * `token` - CSRF 令牌
75    pub async fn validate_token(&self, session_id: &str, token: &str) -> CsrfResult<bool> {
76        let tokens = self.tokens.read().await;
77
78        let record = tokens.get(session_id).ok_or(CsrfError::InvalidToken)?;
79
80        if record.expires_at < Utc::now().timestamp() {
81            return Err(CsrfError::TokenExpired);
82        }
83
84        let token_hash = Self::hash_token(token);
85
86        if record.token_hash != token_hash {
87            return Err(CsrfError::InvalidToken);
88        }
89
90        Ok(true)
91    }
92
93    /// 撤销 CSRF 令牌
94    ///
95    /// # Arguments
96    /// * `session_id` - 会话 ID
97    pub async fn revoke_token(&self, session_id: &str) {
98        self.tokens.write().await.remove(session_id);
99    }
100
101    /// 哈希令牌(用于存储)
102    fn hash_token(token: &str) -> String {
103        use sha2::{Digest, Sha256};
104        let mut hasher = Sha256::new();
105        hasher.update(token.as_bytes());
106        let result = hasher.finalize();
107        URL_SAFE_NO_PAD.encode(result)
108    }
109
110    /// 清理过期令牌
111    async fn cleanup_expired_tokens(&self) {
112        let now = Utc::now().timestamp();
113        let mut tokens = self.tokens.write().await;
114        tokens.retain(|_, record| record.expires_at > now);
115    }
116
117    /// 获取配置
118    pub fn config(&self) -> &CsrfConfig {
119        &self.config
120    }
121}
122
123impl Default for CsrfService {
124    fn default() -> Self {
125        Self::new(CsrfConfig::default())
126    }
127}