1use std::sync::Arc;
9use chrono::{DateTime, Utc, Duration};
10use sa_token_adapter::storage::SaStorage;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::config::SaTokenConfig;
14use uuid::Uuid;
15
16#[derive(Clone)]
21pub struct RefreshTokenManager {
22 storage: Arc<dyn SaStorage>,
23 config: Arc<SaTokenConfig>,
24}
25
26impl RefreshTokenManager {
27 pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
29 Self { storage, config }
30 }
31
32 fn refresh_key(&self, refresh_token: &str) -> String {
34 self.config.make_key("refresh:", refresh_token)
35 }
36
37 fn user_index_key(&self, login_id: &str) -> String {
39 self.config.make_key("refresh:user:", login_id)
40 }
41
42 async fn load_string_list(&self, key: &str) -> SaTokenResult<Vec<String>> {
43 match self
44 .storage
45 .get(key)
46 .await
47 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
48 {
49 Some(value) => serde_json::from_str(&value).map_err(SaTokenError::SerializationError),
50 None => Ok(Vec::new()),
51 }
52 }
53
54 async fn save_string_list(&self, key: &str, list: &[String]) -> SaTokenResult<()> {
55 let value = serde_json::to_string(list).map_err(SaTokenError::SerializationError)?;
56 self.storage
57 .set(key, &value, None)
58 .await
59 .map_err(|e| SaTokenError::StorageError(e.to_string()))
60 }
61
62 async fn append_user_index(&self, login_id: &str, refresh_token: &str) -> SaTokenResult<()> {
64 let key = self.user_index_key(login_id);
65 let mut list = self.load_string_list(&key).await?;
66 if !list.iter().any(|t| t == refresh_token) {
67 list.push(refresh_token.to_string());
68 self.save_string_list(&key, &list).await?;
69 }
70 Ok(())
71 }
72
73 async fn remove_user_index(&self, login_id: &str, refresh_token: &str) -> SaTokenResult<()> {
75 let key = self.user_index_key(login_id);
76 let mut list = self.load_string_list(&key).await?;
77 let before = list.len();
78 list.retain(|t| t != refresh_token);
79 if list.len() != before {
80 self.save_string_list(&key, &list).await?;
81 }
82 Ok(())
83 }
84
85 pub fn generate(&self, login_id: &str) -> String {
87 format!(
88 "refresh_{}_{}_{}",
89 Utc::now().timestamp_millis(),
90 login_id,
91 Uuid::new_v4().simple()
92 )
93 }
94
95 pub async fn store(
97 &self,
98 refresh_token: &str,
99 access_token: &str,
100 login_id: &str,
101 ) -> SaTokenResult<()> {
102 self.store_with_extra(refresh_token, access_token, login_id, None)
103 .await
104 }
105
106 pub async fn store_with_extra(
108 &self,
109 refresh_token: &str,
110 access_token: &str,
111 login_id: &str,
112 extra_data: Option<&serde_json::Value>,
113 ) -> SaTokenResult<()> {
114 let key = self.refresh_key(refresh_token);
115 let expire_time = if self.config.refresh_token_timeout > 0 {
116 Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
117 } else {
118 None
119 };
120
121 let mut obj = serde_json::json!({
122 "access_token": access_token,
123 "login_id": login_id,
124 "created_at": Utc::now().to_rfc3339(),
125 "expire_time": expire_time.map(|t| t.to_rfc3339()),
126 });
127 if let Some(extra) = extra_data {
128 obj["extra_data"] = extra.clone();
129 }
130 let value = obj.to_string();
131
132 let ttl = if self.config.refresh_token_timeout > 0 {
133 Some(std::time::Duration::from_secs(
134 self.config.refresh_token_timeout as u64,
135 ))
136 } else {
137 None
138 };
139
140 self.storage
141 .set(&key, &value, ttl)
142 .await
143 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
144
145 self.append_user_index(login_id, refresh_token).await?;
146 Ok(())
147 }
148
149 pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
151 let key = self.refresh_key(refresh_token);
152
153 let value_str = self
154 .storage
155 .get(&key)
156 .await
157 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
158 .ok_or(SaTokenError::RefreshTokenNotFound)?;
159
160 let value: serde_json::Value = serde_json::from_str(&value_str)
161 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
162
163 let login_id = value["login_id"]
164 .as_str()
165 .ok_or(SaTokenError::RefreshTokenMissingLoginId)?
166 .to_string();
167
168 if let Some(expire_str) = value["expire_time"].as_str() {
169 let expire_time = DateTime::parse_from_rfc3339(expire_str)
170 .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
171 .with_timezone(&Utc);
172
173 if Utc::now() > expire_time {
174 self.delete(refresh_token).await?;
175 return Err(SaTokenError::TokenExpired);
176 }
177 }
178
179 Ok(login_id)
180 }
181
182 pub async fn refresh_access_token(
186 &self,
187 refresh_token: &str,
188 ) -> SaTokenResult<(TokenValue, String)> {
189 let login_id = self.validate(refresh_token).await?;
190
191 let key = self.refresh_key(refresh_token);
192 let value_str = self
193 .storage
194 .get(&key)
195 .await
196 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
197 .ok_or(SaTokenError::RefreshTokenNotFound)?;
198
199 let mut value: serde_json::Value = serde_json::from_str(&value_str)
200 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
201
202 let extra_data = value.get("extra_data").cloned();
203 let new_access_token = match &extra_data {
204 Some(extra) => {
205 TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra)
206 }
207 None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
208 };
209
210 let mut token_info = TokenInfo::new(new_access_token.clone(), login_id.clone());
212 token_info.update_active_time();
213 token_info.refresh_token = Some(refresh_token.to_string());
214 if self.config.refresh_token_timeout > 0 {
215 token_info.refresh_token_expire_time = Some(
216 Utc::now() + Duration::seconds(self.config.refresh_token_timeout),
217 );
218 }
219 if let Some(extra) = &extra_data {
220 token_info.extra_data = Some(extra.clone());
221 }
222 if token_info.expire_time.is_none()
223 && let Some(timeout) = self.config.timeout_duration()
224 {
225 token_info.expire_time =
226 Some(Utc::now() + Duration::from_std(timeout).unwrap());
227 }
228
229 let token_key = self.config.make_key("token:", new_access_token.as_str());
230 let token_json = serde_json::to_string(&token_info)
231 .map_err(SaTokenError::SerializationError)?;
232 self.storage
233 .set(&token_key, &token_json, self.config.timeout_duration())
234 .await
235 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
236
237 let login_token_key = self.config.make_key("login:token:", &login_id);
239 self.storage
240 .set(
241 &login_token_key,
242 new_access_token.as_str(),
243 self.config.timeout_duration(),
244 )
245 .await
246 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
247
248 value["access_token"] = serde_json::json!(new_access_token.as_str());
249 value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
250
251 let ttl = if self.config.refresh_token_timeout > 0 {
252 Some(std::time::Duration::from_secs(
253 self.config.refresh_token_timeout as u64,
254 ))
255 } else {
256 None
257 };
258
259 self.storage
260 .set(&key, &value.to_string(), ttl)
261 .await
262 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
263
264 Ok((new_access_token, login_id))
265 }
266
267 pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
269 let key = self.refresh_key(refresh_token);
270
271 if let Ok(Some(value_str)) = self.storage.get(&key).await {
273 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&value_str)
274 && let Some(login_id) = value["login_id"].as_str()
275 {
276 let _ = self.remove_user_index(login_id, refresh_token).await;
277 }
278 }
279
280 self.storage
281 .delete(&key)
282 .await
283 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
284 Ok(())
285 }
286
287 pub async fn get_user_refresh_tokens(&self, login_id: &str) -> SaTokenResult<Vec<String>> {
289 self.load_string_list(&self.user_index_key(login_id)).await
290 }
291
292 pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
294 let tokens = self.get_user_refresh_tokens(login_id).await?;
295 for token in tokens {
296 self.delete(&token).await?;
297 }
298 Ok(())
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    /// Login id shared by every scenario below.
    const USER: &str = "user_123";

    /// UUID-style access tokens, 1h access lifetime, 2h refresh lifetime.
    fn create_test_config() -> Arc<SaTokenConfig> {
        let config = SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        };
        Arc::new(config)
    }

    /// Fresh in-memory storage, the test config, and a manager bound to both.
    fn setup() -> (Arc<MemoryStorage>, Arc<SaTokenConfig>, RefreshTokenManager) {
        let backing = Arc::new(MemoryStorage::new());
        let cfg = create_test_config();
        let manager = RefreshTokenManager::new(backing.clone(), cfg.clone());
        (backing, cfg, manager)
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let (_, _, manager) = setup();

        let first = manager.generate(USER);
        let second = manager.generate(USER);

        assert!(first.starts_with("refresh_"));
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let (_, _, manager) = setup();
        let rt = manager.generate(USER);

        manager.store(&rt, "access_token_123", USER).await.unwrap();

        let owner = manager.validate(&rt).await.unwrap();
        assert_eq!(owner, USER);

        let indexed = manager.get_user_refresh_tokens(USER).await.unwrap();
        assert_eq!(indexed, vec![rt]);
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        let (backing, cfg, manager) = setup();
        let rt = manager.generate(USER);
        let previous_access = "old_access_token";

        manager.store(&rt, previous_access, USER).await.unwrap();

        let (fresh_access, owner) = manager.refresh_access_token(&rt).await.unwrap();
        assert_eq!(owner, USER);
        assert_ne!(fresh_access.as_str(), previous_access);

        let session_key = cfg.make_key("token:", fresh_access.as_str());
        let session = backing.get(&session_key).await.unwrap();
        assert!(session.is_some());
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let (_, _, manager) = setup();
        let rt = manager.generate(USER);
        manager.store(&rt, "access", USER).await.unwrap();

        manager.delete(&rt).await.unwrap();

        assert!(manager.validate(&rt).await.is_err());
        assert!(manager.get_user_refresh_tokens(USER).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_revoke_all_for_user() {
        let (_, _, manager) = setup();
        let first = manager.generate(USER);
        let second = manager.generate(USER);
        manager.store(&first, "a1", USER).await.unwrap();
        manager.store(&second, "a2", USER).await.unwrap();

        manager.revoke_all_for_user(USER).await.unwrap();

        for rt in [&first, &second] {
            assert!(manager.validate(rt).await.is_err());
        }
        let remaining = manager.get_user_refresh_tokens(USER).await.unwrap();
        assert!(remaining.is_empty());
    }
}