// Source: sa_token_core/refresh.rs

// Author: 金书记
//
//! Refresh Token Module | Refresh Token 模块
//!
//! Implements the refresh-token mechanism used to renew access tokens for
//! long-term authentication without requiring the user to log in again.
//! 实现长期认证的 Token 刷新机制

use std::sync::Arc;

use chrono::{DateTime, Duration, Utc};
use sa_token_adapter::storage::SaStorage;
use uuid::Uuid;

use crate::config::SaTokenConfig;
use crate::error::{SaTokenError, SaTokenResult};
use crate::token::TokenGenerator;
use crate::token::TokenValue;

17/// Refresh Token Manager | Refresh Token 管理器
18///
19/// Manages refresh token generation, validation, and access token renewal
20/// 管理 refresh token 的生成、验证和访问令牌的更新
21#[derive(Clone)]
22pub struct RefreshTokenManager {
23    storage: Arc<dyn SaStorage>,
24    config: Arc<SaTokenConfig>,
25}
26
27impl RefreshTokenManager {
28    /// Create new refresh token manager | 创建新的 refresh token 管理器
29    ///
30    /// # Arguments | 参数
31    ///
32    /// * `storage` - Storage backend | 存储后端
33    /// * `config` - Sa-token configuration | Sa-token 配置
34    pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
35        Self { storage, config }
36    }
37
38    /// Generate a new refresh token | 生成新的 refresh token
39    ///
40    /// # Arguments | 参数
41    ///
42    /// * `login_id` - User login ID | 用户登录ID
43    ///
44    /// # Returns | 返回
45    ///
46    /// Refresh token string | Refresh token 字符串
47    pub fn generate(&self, login_id: &str) -> String {
48        // Format: refresh_TIMESTAMP_LOGINID_UUID
49        format!(
50            "refresh_{}_{}_{}",
51            Utc::now().timestamp_millis(),
52            login_id,
53            Uuid::new_v4().simple()
54        )
55    }
56
57    /// Store refresh token with associated access token | 存储 refresh token 及其关联的访问令牌
58    ///
59    /// # Arguments | 参数
60    ///
61    /// * `refresh_token` - Refresh token | Refresh token
62    /// * `access_token` - Associated access token | 关联的访问令牌
63    /// * `login_id` - User login ID | 用户登录ID
64    /// * `extra_data` - Extra data from JWT claims (tenant_id, username, etc.)
65    pub async fn store(
66        &self,
67        refresh_token: &str,
68        access_token: &str,
69        login_id: &str,
70    ) -> SaTokenResult<()> {
71        self.store_with_extra(refresh_token, access_token, login_id, None).await
72    }
73
74    /// Store refresh token with associated access token and extra data
75    ///
76    /// # Arguments | 参数
77    ///
78    /// * `refresh_token` - Refresh token | Refresh token
79    /// * `access_token` - Associated access token | 关联的访问令牌
80    /// * `login_id` - User login ID | 用户登录ID
81    /// * `extra_data` - Optional extra data to preserve across refresh
82    pub async fn store_with_extra(
83        &self,
84        refresh_token: &str,
85        access_token: &str,
86        login_id: &str,
87        extra_data: Option<&serde_json::Value>,
88    ) -> SaTokenResult<()> {
89        let key = format!("sa:refresh:{}", refresh_token);
90        let expire_time = if self.config.refresh_token_timeout > 0 {
91            Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
92        } else {
93            None
94        };
95
96        let mut obj = serde_json::json!({
97            "access_token": access_token,
98            "login_id": login_id,
99            "created_at": Utc::now().to_rfc3339(),
100            "expire_time": expire_time.map(|t| t.to_rfc3339()),
101        });
102        if let Some(extra) = extra_data {
103            obj["extra_data"] = extra.clone();
104        }
105        let value = obj.to_string();
106
107        let ttl = if self.config.refresh_token_timeout > 0 {
108            Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
109        } else {
110            None
111        };
112
113        self.storage.set(&key, &value, ttl)
114            .await
115            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
116
117        Ok(())
118    }
119
120    /// Validate refresh token | 验证 refresh token
121    ///
122    /// # Arguments | 参数
123    ///
124    /// * `refresh_token` - Refresh token to validate | 要验证的 refresh token
125    ///
126    /// # Returns | 返回
127    ///
128    /// Associated login_id if valid | 如果有效则返回关联的 login_id
129    pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
130        let key = format!("sa:refresh:{}", refresh_token);
131        
132        let value_str = self.storage.get(&key)
133            .await
134            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
135            .ok_or(SaTokenError::RefreshTokenNotFound)?;
136
137        let value: serde_json::Value = serde_json::from_str(&value_str)
138            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
139
140        let login_id = value["login_id"].as_str()
141            .ok_or(SaTokenError::RefreshTokenMissingLoginId)?
142            .to_string();
143
144        // Check expiration if set
145        if let Some(expire_str) = value["expire_time"].as_str() {
146            let expire_time = DateTime::parse_from_rfc3339(expire_str)
147                .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
148                .with_timezone(&Utc);
149
150            if Utc::now() > expire_time {
151                // Delete expired refresh token
152                self.delete(refresh_token).await?;
153                return Err(SaTokenError::TokenExpired);
154            }
155        }
156
157        Ok(login_id)
158    }
159
160    /// Refresh access token using refresh token | 使用 refresh token 刷新访问令牌
161    ///
162    /// # Arguments | 参数
163    ///
164    /// * `refresh_token` - Refresh token | Refresh token
165    ///
166    /// # Returns | 返回
167    ///
168    /// New access token and login_id | 新的访问令牌和 login_id
169    pub async fn refresh_access_token(
170        &self,
171        refresh_token: &str,
172    ) -> SaTokenResult<(TokenValue, String)> {
173        // Validate refresh token
174        let login_id = self.validate(refresh_token).await?;
175
176        // Read stored refresh token data (contains extra_data)
177        let key = format!("sa:refresh:{}", refresh_token);
178        let value_str = self.storage.get(&key)
179            .await
180            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
181            .ok_or(SaTokenError::RefreshTokenNotFound)?;
182
183        let mut value: serde_json::Value = serde_json::from_str(&value_str)
184            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
185
186        // Generate new access token (with extra_data if present)
187        let extra_data = value.get("extra_data").cloned();
188        let new_access_token = match &extra_data {
189            Some(extra) => TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra),
190            None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
191        };
192
193        // Update stored refresh token with new access token
194        value["access_token"] = serde_json::json!(new_access_token.as_str());
195        value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
196
197        let ttl = if self.config.refresh_token_timeout > 0 {
198            Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
199        } else {
200            None
201        };
202
203        self.storage.set(&key, &value.to_string(), ttl)
204            .await
205            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
206
207        Ok((new_access_token, login_id))
208    }
209
210    /// Delete refresh token | 删除 refresh token
211    ///
212    /// # Arguments | 参数
213    ///
214    /// * `refresh_token` - Refresh token to delete | 要删除的 refresh token
215    pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
216        let key = format!("sa:refresh:{}", refresh_token);
217        self.storage.delete(&key)
218            .await
219            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
220        Ok(())
221    }
222
223    /// Get all refresh tokens for a user | 获取用户的所有 refresh token
224    ///
225    /// Note: This requires storage backend to support prefix scanning
226    /// 注意:这需要存储后端支持前缀扫描
227    pub async fn get_user_refresh_tokens(&self, _login_id: &str) -> SaTokenResult<Vec<String>> {
228        // This is a placeholder - actual implementation depends on storage capabilities
229        // 这是一个占位符 - 实际实现取决于存储能力
230        // Most implementations would need to maintain a separate index
231        // 大多数实现需要维护一个单独的索引
232        Ok(vec![])
233    }
234
235    /// Revoke all refresh tokens for a user | 撤销用户的所有 refresh token
236    ///
237    /// # Arguments | 参数
238    ///
239    /// * `login_id` - User login ID | 用户登录ID
240    pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
241        let tokens = self.get_user_refresh_tokens(login_id).await?;
242        for token in tokens {
243            self.delete(&token).await?;
244        }
245        Ok(())
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    fn create_test_config() -> Arc<SaTokenConfig> {
        Arc::new(SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        })
    }

    /// Manager backed by a fresh, empty in-memory store.
    fn create_manager() -> RefreshTokenManager {
        RefreshTokenManager::new(Arc::new(MemoryStorage::new()), create_test_config())
    }

    /// Raw JSON record stored for `refresh_token`, or `None` if it is absent
    /// or unreadable.
    async fn read_record(
        mgr: &RefreshTokenManager,
        refresh_token: &str,
    ) -> Option<serde_json::Value> {
        let key = format!("sa:refresh:{}", refresh_token);
        let raw = mgr.storage.get(&key).await.ok().flatten()?;
        serde_json::from_str(&raw).ok()
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let refresh_mgr = create_manager();

        let token1 = refresh_mgr.generate("user_123");
        let token2 = refresh_mgr.generate("user_123");

        assert_ne!(token1, token2);
        assert!(token1.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr
            .store(&refresh_token, "access_token_123", "user_123")
            .await
            .unwrap();

        let login_id = refresh_mgr.validate(&refresh_token).await.unwrap();
        assert_eq!(login_id, "user_123");
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        let old_access_token = "old_access_token";
        refresh_mgr
            .store(&refresh_token, old_access_token, "user_123")
            .await
            .unwrap();

        let (new_access_token, login_id) =
            refresh_mgr.refresh_access_token(&refresh_token).await.unwrap();

        assert_eq!(login_id, "user_123");
        assert_ne!(new_access_token.as_str(), old_access_token);

        // The stored record must now point at the new access token.
        let record = read_record(&refresh_mgr, &refresh_token)
            .await
            .expect("refresh token record should still be stored");
        assert_eq!(record["access_token"].as_str(), Some(new_access_token.as_str()));
        assert!(record["refreshed_at"].is_string());

        // The refresh token itself stays usable after a refresh.
        assert_eq!(refresh_mgr.validate(&refresh_token).await.unwrap(), "user_123");
    }

    #[tokio::test]
    async fn test_refresh_preserves_extra_data() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        let extra = serde_json::json!({ "tenant_id": "tenant_1" });
        refresh_mgr
            .store_with_extra(&refresh_token, "access", "user_123", Some(&extra))
            .await
            .unwrap();

        refresh_mgr.refresh_access_token(&refresh_token).await.unwrap();

        let record = read_record(&refresh_mgr, &refresh_token)
            .await
            .expect("refresh token record should still be stored");
        assert_eq!(record["extra_data"], extra);
    }

    #[tokio::test]
    async fn test_validate_expired_refresh_token() {
        let refresh_mgr = create_manager();

        // Write a record whose expire_time is already in the past.
        let refresh_token = refresh_mgr.generate("user_123");
        let past = (Utc::now() - Duration::seconds(60)).to_rfc3339();
        let record = serde_json::json!({
            "access_token": "access",
            "login_id": "user_123",
            "created_at": past.as_str(),
            "expire_time": past.as_str(),
        });
        let key = format!("sa:refresh:{}", refresh_token);
        assert!(refresh_mgr
            .storage
            .set(&key, &record.to_string(), None)
            .await
            .is_ok());

        let result = refresh_mgr.validate(&refresh_token).await;
        assert!(matches!(result, Err(SaTokenError::TokenExpired)));

        // validate removes the expired record.
        assert!(read_record(&refresh_mgr, &refresh_token).await.is_none());
    }

    #[tokio::test]
    async fn test_refresh_unknown_token_fails() {
        let refresh_mgr = create_manager();

        let result = refresh_mgr
            .refresh_access_token("refresh_does_not_exist")
            .await;
        assert!(matches!(result, Err(SaTokenError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let refresh_mgr = create_manager();

        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr.store(&refresh_token, "access", "user_123").await.unwrap();

        refresh_mgr.delete(&refresh_token).await.unwrap();

        // Validation must fail once the record is gone.
        let result = refresh_mgr.validate(&refresh_token).await;
        assert!(result.is_err());
    }
}
