Skip to main content

sa_token_core/
refresh.rs

1// Author: 金书记
2//
3//! Refresh Token Module | Refresh Token 模块
4//!
5//! Implements token refresh mechanism for long-term authentication
6//! 实现长期认证的 Token 刷新机制
7
8use std::sync::Arc;
9use chrono::{DateTime, Utc, Duration};
10use sa_token_adapter::storage::SaStorage;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::config::SaTokenConfig;
14use uuid::Uuid;
15
16/// Refresh Token Manager | Refresh Token 管理器
17///
18/// Manages refresh token generation, validation, and access token renewal
19/// 管理 refresh token 的生成、验证和访问令牌的更新
20#[derive(Clone)]
21pub struct RefreshTokenManager {
22    storage: Arc<dyn SaStorage>,
23    config: Arc<SaTokenConfig>,
24}
25
26impl RefreshTokenManager {
27    /// Create new refresh token manager | 创建新的 refresh token 管理器
28    pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
29        Self { storage, config }
30    }
31
32    /// refresh token 存储键:{prefix}refresh:{token}
33    fn refresh_key(&self, refresh_token: &str) -> String {
34        self.config.make_key("refresh:", refresh_token)
35    }
36
37    /// 用户 refresh token 索引键:{prefix}refresh:user:{login_id}
38    fn user_index_key(&self, login_id: &str) -> String {
39        self.config.make_key("refresh:user:", login_id)
40    }
41
42    async fn load_string_list(&self, key: &str) -> SaTokenResult<Vec<String>> {
43        match self
44            .storage
45            .get(key)
46            .await
47            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
48        {
49            Some(value) => serde_json::from_str(&value).map_err(SaTokenError::SerializationError),
50            None => Ok(Vec::new()),
51        }
52    }
53
54    async fn save_string_list(&self, key: &str, list: &[String]) -> SaTokenResult<()> {
55        let value = serde_json::to_string(list).map_err(SaTokenError::SerializationError)?;
56        self.storage
57            .set(key, &value, None)
58            .await
59            .map_err(|e| SaTokenError::StorageError(e.to_string()))
60    }
61
62    /// 将 refresh token 追加到用户索引(去重)
63    async fn append_user_index(&self, login_id: &str, refresh_token: &str) -> SaTokenResult<()> {
64        let key = self.user_index_key(login_id);
65        let mut list = self.load_string_list(&key).await?;
66        if !list.iter().any(|t| t == refresh_token) {
67            list.push(refresh_token.to_string());
68            self.save_string_list(&key, &list).await?;
69        }
70        Ok(())
71    }
72
73    /// 从用户索引移除 refresh token
74    async fn remove_user_index(&self, login_id: &str, refresh_token: &str) -> SaTokenResult<()> {
75        let key = self.user_index_key(login_id);
76        let mut list = self.load_string_list(&key).await?;
77        let before = list.len();
78        list.retain(|t| t != refresh_token);
79        if list.len() != before {
80            self.save_string_list(&key, &list).await?;
81        }
82        Ok(())
83    }
84
85    /// Generate a new refresh token | 生成新的 refresh token
86    pub fn generate(&self, login_id: &str) -> String {
87        format!(
88            "refresh_{}_{}_{}",
89            Utc::now().timestamp_millis(),
90            login_id,
91            Uuid::new_v4().simple()
92        )
93    }
94
95    /// Store refresh token with associated access token | 存储 refresh token 及其关联的访问令牌
96    pub async fn store(
97        &self,
98        refresh_token: &str,
99        access_token: &str,
100        login_id: &str,
101    ) -> SaTokenResult<()> {
102        self.store_with_extra(refresh_token, access_token, login_id, None)
103            .await
104    }
105
106    /// Store refresh token with associated access token and extra data
107    pub async fn store_with_extra(
108        &self,
109        refresh_token: &str,
110        access_token: &str,
111        login_id: &str,
112        extra_data: Option<&serde_json::Value>,
113    ) -> SaTokenResult<()> {
114        let key = self.refresh_key(refresh_token);
115        let expire_time = if self.config.refresh_token_timeout > 0 {
116            Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
117        } else {
118            None
119        };
120
121        let mut obj = serde_json::json!({
122            "access_token": access_token,
123            "login_id": login_id,
124            "created_at": Utc::now().to_rfc3339(),
125            "expire_time": expire_time.map(|t| t.to_rfc3339()),
126        });
127        if let Some(extra) = extra_data {
128            obj["extra_data"] = extra.clone();
129        }
130        let value = obj.to_string();
131
132        let ttl = if self.config.refresh_token_timeout > 0 {
133            Some(std::time::Duration::from_secs(
134                self.config.refresh_token_timeout as u64,
135            ))
136        } else {
137            None
138        };
139
140        self.storage
141            .set(&key, &value, ttl)
142            .await
143            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
144
145        self.append_user_index(login_id, refresh_token).await?;
146        Ok(())
147    }
148
149    /// Validate refresh token | 验证 refresh token
150    pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
151        let key = self.refresh_key(refresh_token);
152
153        let value_str = self
154            .storage
155            .get(&key)
156            .await
157            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
158            .ok_or(SaTokenError::RefreshTokenNotFound)?;
159
160        let value: serde_json::Value = serde_json::from_str(&value_str)
161            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
162
163        let login_id = value["login_id"]
164            .as_str()
165            .ok_or(SaTokenError::RefreshTokenMissingLoginId)?
166            .to_string();
167
168        if let Some(expire_str) = value["expire_time"].as_str() {
169            let expire_time = DateTime::parse_from_rfc3339(expire_str)
170                .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
171                .with_timezone(&Utc);
172
173            if Utc::now() > expire_time {
174                self.delete(refresh_token).await?;
175                return Err(SaTokenError::TokenExpired);
176            }
177        }
178
179        Ok(login_id)
180    }
181
182    /// Refresh access token using refresh token | 使用 refresh token 刷新访问令牌
183    ///
184    /// 生成新 access token 并回写 `{prefix}token:{token}` 存储,与 SaTokenManager 登录态对齐。
185    pub async fn refresh_access_token(
186        &self,
187        refresh_token: &str,
188    ) -> SaTokenResult<(TokenValue, String)> {
189        let login_id = self.validate(refresh_token).await?;
190
191        let key = self.refresh_key(refresh_token);
192        let value_str = self
193            .storage
194            .get(&key)
195            .await
196            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
197            .ok_or(SaTokenError::RefreshTokenNotFound)?;
198
199        let mut value: serde_json::Value = serde_json::from_str(&value_str)
200            .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
201
202        let extra_data = value.get("extra_data").cloned();
203        let new_access_token = match &extra_data {
204            Some(extra) => {
205                TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra)
206            }
207            None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
208        };
209
210        // 构造并写入新的 TokenInfo(与 Manager 登录路径一致的存储键)
211        let mut token_info = TokenInfo::new(new_access_token.clone(), login_id.clone());
212        token_info.update_active_time();
213        token_info.refresh_token = Some(refresh_token.to_string());
214        if self.config.refresh_token_timeout > 0 {
215            token_info.refresh_token_expire_time = Some(
216                Utc::now() + Duration::seconds(self.config.refresh_token_timeout),
217            );
218        }
219        if let Some(extra) = &extra_data {
220            token_info.extra_data = Some(extra.clone());
221        }
222        if token_info.expire_time.is_none()
223            && let Some(timeout) = self.config.timeout_duration()
224        {
225            token_info.expire_time =
226                Some(Utc::now() + Duration::from_std(timeout).unwrap());
227        }
228
229        let token_key = self.config.make_key("token:", new_access_token.as_str());
230        let token_json = serde_json::to_string(&token_info)
231            .map_err(SaTokenError::SerializationError)?;
232        self.storage
233            .set(&token_key, &token_json, self.config.timeout_duration())
234            .await
235            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
236
237        // 更新 login_id -> token 映射
238        let login_token_key = self.config.make_key("login:token:", &login_id);
239        self.storage
240            .set(
241                &login_token_key,
242                new_access_token.as_str(),
243                self.config.timeout_duration(),
244            )
245            .await
246            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
247
248        value["access_token"] = serde_json::json!(new_access_token.as_str());
249        value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
250
251        let ttl = if self.config.refresh_token_timeout > 0 {
252            Some(std::time::Duration::from_secs(
253                self.config.refresh_token_timeout as u64,
254            ))
255        } else {
256            None
257        };
258
259        self.storage
260            .set(&key, &value.to_string(), ttl)
261            .await
262            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
263
264        Ok((new_access_token, login_id))
265    }
266
267    /// Delete refresh token | 删除 refresh token
268    pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
269        let key = self.refresh_key(refresh_token);
270
271        // 读取 login_id 以便清理用户索引
272        if let Ok(Some(value_str)) = self.storage.get(&key).await {
273            if let Ok(value) = serde_json::from_str::<serde_json::Value>(&value_str)
274                && let Some(login_id) = value["login_id"].as_str()
275            {
276                let _ = self.remove_user_index(login_id, refresh_token).await;
277            }
278        }
279
280        self.storage
281            .delete(&key)
282            .await
283            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
284        Ok(())
285    }
286
287    /// Get all refresh tokens for a user | 获取用户的所有 refresh token
288    pub async fn get_user_refresh_tokens(&self, login_id: &str) -> SaTokenResult<Vec<String>> {
289        self.load_string_list(&self.user_index_key(login_id)).await
290    }
291
292    /// Revoke all refresh tokens for a user | 撤销用户的所有 refresh token
293    pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
294        let tokens = self.get_user_refresh_tokens(login_id).await?;
295        for token in tokens {
296            self.delete(&token).await?;
297        }
298        Ok(())
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    /// Login id shared by every test case.
    const USER: &str = "user_123";

    /// Test configuration: UUID tokens, 1h access timeout, 2h refresh timeout.
    fn create_test_config() -> Arc<SaTokenConfig> {
        let config = SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        };
        Arc::new(config)
    }

    /// Fresh in-memory storage, the test config, and a manager wired to both.
    fn fixture() -> (Arc<MemoryStorage>, Arc<SaTokenConfig>, RefreshTokenManager) {
        let store = Arc::new(MemoryStorage::new());
        let cfg = create_test_config();
        let manager = RefreshTokenManager::new(store.clone(), cfg.clone());
        (store, cfg, manager)
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let (_, _, manager) = fixture();

        let first = manager.generate(USER);
        let second = manager.generate(USER);

        assert_ne!(first, second);
        assert!(first.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let (_, _, manager) = fixture();
        let rt = manager.generate(USER);

        manager.store(&rt, "access_token_123", USER).await.unwrap();

        let owner = manager.validate(&rt).await.unwrap();
        assert_eq!(owner, USER);

        let indexed = manager.get_user_refresh_tokens(USER).await.unwrap();
        assert_eq!(indexed, vec![rt]);
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        let (store, cfg, manager) = fixture();
        let rt = manager.generate(USER);
        let previous = "old_access_token";

        manager.store(&rt, previous, USER).await.unwrap();

        let (fresh, owner) = manager.refresh_access_token(&rt).await.unwrap();
        assert_eq!(owner, USER);
        assert_ne!(fresh.as_str(), previous);

        // The new access token must have been written to token storage.
        let token_key = cfg.make_key("token:", fresh.as_str());
        assert!(store.get(&token_key).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let (_, _, manager) = fixture();
        let rt = manager.generate(USER);
        manager.store(&rt, "access", USER).await.unwrap();

        manager.delete(&rt).await.unwrap();

        assert!(manager.validate(&rt).await.is_err());
        assert!(manager.get_user_refresh_tokens(USER).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_revoke_all_for_user() {
        let (_, _, manager) = fixture();
        let first = manager.generate(USER);
        let second = manager.generate(USER);
        manager.store(&first, "a1", USER).await.unwrap();
        manager.store(&second, "a2", USER).await.unwrap();

        manager.revoke_all_for_user(USER).await.unwrap();

        for rt in [&first, &second] {
            assert!(manager.validate(rt).await.is_err());
        }
        let remaining = manager.get_user_refresh_tokens(USER).await.unwrap();
        assert!(remaining.is_empty());
    }
}