// File: sa_token_core/refresh.rs
//
// Author: 金书记
//
//! Refresh Token Module | Refresh Token 模块
//!
//! Implements the token refresh mechanism used for long-term authentication.
//! 实现长期认证的 Token 刷新机制

8use std::sync::Arc;
9use chrono::{DateTime, Utc, Duration};
10use sa_token_adapter::storage::SaStorage;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::TokenValue;
13use crate::token::TokenGenerator;
14use crate::config::SaTokenConfig;
15use uuid::Uuid;
16
17/// Refresh Token Manager | Refresh Token 管理器
18///
19/// Manages refresh token generation, validation, and access token renewal
20/// 管理 refresh token 的生成、验证和访问令牌的更新
21#[derive(Clone)]
22pub struct RefreshTokenManager {
23    storage: Arc<dyn SaStorage>,
24    config: Arc<SaTokenConfig>,
25}
26
27impl RefreshTokenManager {
28    /// Create new refresh token manager | 创建新的 refresh token 管理器
29    ///
30    /// # Arguments | 参数
31    ///
32    /// * `storage` - Storage backend | 存储后端
33    /// * `config` - Sa-token configuration | Sa-token 配置
34    pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
35        Self { storage, config }
36    }
37
38    /// Generate a new refresh token | 生成新的 refresh token
39    ///
40    /// # Arguments | 参数
41    ///
42    /// * `login_id` - User login ID | 用户登录ID
43    ///
44    /// # Returns | 返回
45    ///
46    /// Refresh token string | Refresh token 字符串
47    pub fn generate(&self, login_id: &str) -> String {
48        // Format: refresh_TIMESTAMP_LOGINID_UUID
49        format!(
50            "refresh_{}_{}_{}",
51            Utc::now().timestamp_millis(),
52            login_id,
53            Uuid::new_v4().simple()
54        )
55    }
56
57    /// Store refresh token with associated access token | 存储 refresh token 及其关联的访问令牌
58    ///
59    /// # Arguments | 参数
60    ///
61    /// * `refresh_token` - Refresh token | Refresh token
62    /// * `access_token` - Associated access token | 关联的访问令牌
63    /// * `login_id` - User login ID | 用户登录ID
64    pub async fn store(
65        &self,
66        refresh_token: &str,
67        access_token: &str,
68        login_id: &str,
69    ) -> SaTokenResult<()> {
70        let key = format!("sa:refresh:{}", refresh_token);
71        let expire_time = if self.config.refresh_token_timeout > 0 {
72            Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
73        } else {
74            None
75        };
76
77        let value = serde_json::json!({
78            "access_token": access_token,
79            "login_id": login_id,
80            "created_at": Utc::now().to_rfc3339(),
81            "expire_time": expire_time.map(|t| t.to_rfc3339()),
82        }).to_string();
83
84        let ttl = if self.config.refresh_token_timeout > 0 {
85            Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
86        } else {
87            None
88        };
89
90        self.storage.set(&key, &value, ttl)
91            .await
92            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
93
94        Ok(())
95    }
96
97    /// Validate refresh token | 验证 refresh token
98    ///
99    /// # Arguments | 参数
100    ///
101    /// * `refresh_token` - Refresh token to validate | 要验证的 refresh token
102    ///
103    /// # Returns | 返回
104    ///
105    /// Associated login_id if valid | 如果有效则返回关联的 login_id
106    pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
107        let key = format!("sa:refresh:{}", refresh_token);
108        
109        let value_str = self.storage.get(&key)
110            .await
111            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
112            .ok_or_else(|| SaTokenError::RefreshTokenNotFound)?;
113
114        let value: serde_json::Value = serde_json::from_str(&value_str)
115            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
116
117        let login_id = value["login_id"].as_str()
118            .ok_or_else(|| SaTokenError::RefreshTokenMissingLoginId)?
119            .to_string();
120
121        // Check expiration if set
122        if let Some(expire_str) = value["expire_time"].as_str() {
123            let expire_time = DateTime::parse_from_rfc3339(expire_str)
124                .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
125                .with_timezone(&Utc);
126
127            if Utc::now() > expire_time {
128                // Delete expired refresh token
129                self.delete(refresh_token).await?;
130                return Err(SaTokenError::TokenExpired);
131            }
132        }
133
134        Ok(login_id)
135    }
136
137    /// Refresh access token using refresh token | 使用 refresh token 刷新访问令牌
138    ///
139    /// # Arguments | 参数
140    ///
141    /// * `refresh_token` - Refresh token | Refresh token
142    ///
143    /// # Returns | 返回
144    ///
145    /// New access token and login_id | 新的访问令牌和 login_id
146    pub async fn refresh_access_token(
147        &self,
148        refresh_token: &str,
149    ) -> SaTokenResult<(TokenValue, String)> {
150        // Validate refresh token
151        let login_id = self.validate(refresh_token).await?;
152
153        // Generate new access token
154        let new_access_token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
155
156        // Update stored refresh token with new access token
157        let key = format!("sa:refresh:{}", refresh_token);
158        let value_str = self.storage.get(&key)
159            .await
160            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
161            .ok_or_else(|| SaTokenError::RefreshTokenNotFound)?;
162
163        let mut value: serde_json::Value = serde_json::from_str(&value_str)
164            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
165
166        value["access_token"] = serde_json::json!(new_access_token.as_str());
167        value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
168
169        let ttl = if self.config.refresh_token_timeout > 0 {
170            Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
171        } else {
172            None
173        };
174
175        self.storage.set(&key, &value.to_string(), ttl)
176            .await
177            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
178
179        Ok((new_access_token, login_id))
180    }
181
182    /// Delete refresh token | 删除 refresh token
183    ///
184    /// # Arguments | 参数
185    ///
186    /// * `refresh_token` - Refresh token to delete | 要删除的 refresh token
187    pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
188        let key = format!("sa:refresh:{}", refresh_token);
189        self.storage.delete(&key)
190            .await
191            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
192        Ok(())
193    }
194
195    /// Get all refresh tokens for a user | 获取用户的所有 refresh token
196    ///
197    /// Note: This requires storage backend to support prefix scanning
198    /// 注意:这需要存储后端支持前缀扫描
199    pub async fn get_user_refresh_tokens(&self, _login_id: &str) -> SaTokenResult<Vec<String>> {
200        // This is a placeholder - actual implementation depends on storage capabilities
201        // 这是一个占位符 - 实际实现取决于存储能力
202        // Most implementations would need to maintain a separate index
203        // 大多数实现需要维护一个单独的索引
204        Ok(vec![])
205    }
206
207    /// Revoke all refresh tokens for a user | 撤销用户的所有 refresh token
208    ///
209    /// # Arguments | 参数
210    ///
211    /// * `login_id` - User login ID | 用户登录ID
212    pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
213        let tokens = self.get_user_refresh_tokens(login_id).await?;
214        for token in tokens {
215            self.delete(&token).await?;
216        }
217        Ok(())
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    fn create_test_config() -> Arc<SaTokenConfig> {
        Arc::new(SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        })
    }

    /// Builds a manager backed by fresh in-memory storage.
    fn create_manager() -> RefreshTokenManager {
        let storage = Arc::new(MemoryStorage::new());
        RefreshTokenManager::new(storage, create_test_config())
    }

    /// Reads and parses the raw JSON record stored for `refresh_token`.
    async fn stored_record(mgr: &RefreshTokenManager, refresh_token: &str) -> serde_json::Value {
        let raw = mgr
            .storage
            .get(&format!("sa:refresh:{}", refresh_token))
            .await
            .ok()
            .flatten()
            .expect("refresh-token record should be stored");
        serde_json::from_str(&raw).expect("record should be valid JSON")
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let refresh_mgr = create_manager();

        let token1 = refresh_mgr.generate("user_123");
        let token2 = refresh_mgr.generate("user_123");

        assert_ne!(token1, token2);
        assert!(token1.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr
            .store(&refresh_token, "access_token_123", "user_123")
            .await
            .unwrap();

        let login_id = refresh_mgr.validate(&refresh_token).await.unwrap();
        assert_eq!(login_id, "user_123");
    }

    #[tokio::test]
    async fn test_store_records_fields() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr
            .store(&refresh_token, "access_token_123", "user_123")
            .await
            .unwrap();

        let record = stored_record(&refresh_mgr, &refresh_token).await;
        assert_eq!(record["access_token"], "access_token_123");
        assert_eq!(record["login_id"], "user_123");
        assert!(record["created_at"].is_string());
        // The test config has a positive refresh_token_timeout, so an expiry must be recorded.
        assert!(record["expire_time"].is_string());
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        let old_access_token = "old_access_token";
        refresh_mgr
            .store(&refresh_token, old_access_token, "user_123")
            .await
            .unwrap();

        let (new_access_token, login_id) = refresh_mgr
            .refresh_access_token(&refresh_token)
            .await
            .unwrap();

        assert_eq!(login_id, "user_123");
        assert_ne!(new_access_token.as_str(), old_access_token);

        // The stored record now points at the new access token.
        let record = stored_record(&refresh_mgr, &refresh_token).await;
        assert_eq!(record["access_token"], new_access_token.as_str());
        assert!(record["refreshed_at"].is_string());

        // The refresh token itself remains usable.
        assert_eq!(refresh_mgr.validate(&refresh_token).await.unwrap(), "user_123");
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr.store(&refresh_token, "access", "user_123").await.unwrap();

        refresh_mgr.delete(&refresh_token).await.unwrap();

        // Both validation and refresh must now fail.
        assert!(refresh_mgr.validate(&refresh_token).await.is_err());
        assert!(refresh_mgr.refresh_access_token(&refresh_token).await.is_err());
    }

    #[tokio::test]
    async fn test_validate_unknown_token() {
        let refresh_mgr = create_manager();

        let result = refresh_mgr.validate("refresh_does_not_exist").await;
        assert!(matches!(result, Err(SaTokenError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_refresh_unknown_token() {
        let refresh_mgr = create_manager();

        let result = refresh_mgr.refresh_access_token("refresh_does_not_exist").await;
        assert!(matches!(result, Err(SaTokenError::RefreshTokenNotFound)));
    }
}
303