1use std::sync::Arc;
9use chrono::{DateTime, Utc, Duration};
10use sa_token_adapter::storage::SaStorage;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::TokenValue;
13use crate::token::TokenGenerator;
14use crate::config::SaTokenConfig;
15use uuid::Uuid;
16
17#[derive(Clone)]
22pub struct RefreshTokenManager {
23 storage: Arc<dyn SaStorage>,
24 config: Arc<SaTokenConfig>,
25}
26
27impl RefreshTokenManager {
28 pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
35 Self { storage, config }
36 }
37
38 pub fn generate(&self, login_id: &str) -> String {
48 format!(
50 "refresh_{}_{}_{}",
51 Utc::now().timestamp_millis(),
52 login_id,
53 Uuid::new_v4().simple()
54 )
55 }
56
57 pub async fn store(
66 &self,
67 refresh_token: &str,
68 access_token: &str,
69 login_id: &str,
70 ) -> SaTokenResult<()> {
71 self.store_with_extra(refresh_token, access_token, login_id, None).await
72 }
73
74 pub async fn store_with_extra(
83 &self,
84 refresh_token: &str,
85 access_token: &str,
86 login_id: &str,
87 extra_data: Option<&serde_json::Value>,
88 ) -> SaTokenResult<()> {
89 let key = format!("sa:refresh:{}", refresh_token);
90 let expire_time = if self.config.refresh_token_timeout > 0 {
91 Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
92 } else {
93 None
94 };
95
96 let mut obj = serde_json::json!({
97 "access_token": access_token,
98 "login_id": login_id,
99 "created_at": Utc::now().to_rfc3339(),
100 "expire_time": expire_time.map(|t| t.to_rfc3339()),
101 });
102 if let Some(extra) = extra_data {
103 obj["extra_data"] = extra.clone();
104 }
105 let value = obj.to_string();
106
107 let ttl = if self.config.refresh_token_timeout > 0 {
108 Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
109 } else {
110 None
111 };
112
113 self.storage.set(&key, &value, ttl)
114 .await
115 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
116
117 Ok(())
118 }
119
120 pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
130 let key = format!("sa:refresh:{}", refresh_token);
131
132 let value_str = self.storage.get(&key)
133 .await
134 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
135 .ok_or(SaTokenError::RefreshTokenNotFound)?;
136
137 let value: serde_json::Value = serde_json::from_str(&value_str)
138 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
139
140 let login_id = value["login_id"].as_str()
141 .ok_or(SaTokenError::RefreshTokenMissingLoginId)?
142 .to_string();
143
144 if let Some(expire_str) = value["expire_time"].as_str() {
146 let expire_time = DateTime::parse_from_rfc3339(expire_str)
147 .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
148 .with_timezone(&Utc);
149
150 if Utc::now() > expire_time {
151 self.delete(refresh_token).await?;
153 return Err(SaTokenError::TokenExpired);
154 }
155 }
156
157 Ok(login_id)
158 }
159
160 pub async fn refresh_access_token(
170 &self,
171 refresh_token: &str,
172 ) -> SaTokenResult<(TokenValue, String)> {
173 let login_id = self.validate(refresh_token).await?;
175
176 let key = format!("sa:refresh:{}", refresh_token);
178 let value_str = self.storage.get(&key)
179 .await
180 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
181 .ok_or(SaTokenError::RefreshTokenNotFound)?;
182
183 let mut value: serde_json::Value = serde_json::from_str(&value_str)
184 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
185
186 let extra_data = value.get("extra_data").cloned();
188 let new_access_token = match &extra_data {
189 Some(extra) => TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra),
190 None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
191 };
192
193 value["access_token"] = serde_json::json!(new_access_token.as_str());
195 value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
196
197 let ttl = if self.config.refresh_token_timeout > 0 {
198 Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
199 } else {
200 None
201 };
202
203 self.storage.set(&key, &value.to_string(), ttl)
204 .await
205 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
206
207 Ok((new_access_token, login_id))
208 }
209
210 pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
216 let key = format!("sa:refresh:{}", refresh_token);
217 self.storage.delete(&key)
218 .await
219 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
220 Ok(())
221 }
222
223 pub async fn get_user_refresh_tokens(&self, _login_id: &str) -> SaTokenResult<Vec<String>> {
228 Ok(vec![])
233 }
234
235 pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
241 let tokens = self.get_user_refresh_tokens(login_id).await?;
242 for token in tokens {
243 self.delete(&token).await?;
244 }
245 Ok(())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    /// Config with a 2-hour refresh-token lifetime and refresh tokens enabled.
    fn create_test_config() -> Arc<SaTokenConfig> {
        Arc::new(SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        })
    }

    /// Builds a manager over fresh in-memory storage. Also returns the storage
    /// handle so tests can seed or inspect raw records.
    fn setup() -> (Arc<dyn SaStorage>, RefreshTokenManager) {
        let storage: Arc<dyn SaStorage> = Arc::new(MemoryStorage::new());
        let manager = RefreshTokenManager::new(Arc::clone(&storage), create_test_config());
        (storage, manager)
    }

    /// Storage key the manager uses for `refresh_token`.
    fn record_key(refresh_token: &str) -> String {
        format!("sa:refresh:{}", refresh_token)
    }

    /// Reads and parses the raw JSON record stored for `refresh_token`.
    async fn read_record(storage: &Arc<dyn SaStorage>, refresh_token: &str) -> serde_json::Value {
        let raw = storage
            .get(&record_key(refresh_token))
            .await
            .unwrap()
            .expect("refresh-token record should exist");
        serde_json::from_str(&raw).unwrap()
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let (_, refresh_mgr) = setup();

        let token1 = refresh_mgr.generate("user_123");
        let token2 = refresh_mgr.generate("user_123");

        assert_ne!(token1, token2);
        assert!(token1.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let (_, refresh_mgr) = setup();
        let refresh_token = refresh_mgr.generate("user_123");

        refresh_mgr
            .store(&refresh_token, "access_token_123", "user_123")
            .await
            .unwrap();

        let login_id = refresh_mgr.validate(&refresh_token).await.unwrap();
        assert_eq!(login_id, "user_123");
    }

    #[tokio::test]
    async fn test_store_with_extra_persists_extra_data() {
        let (storage, refresh_mgr) = setup();
        let refresh_token = refresh_mgr.generate("user_123");
        let extra = serde_json::json!({ "device": "web" });

        refresh_mgr
            .store_with_extra(&refresh_token, "access", "user_123", Some(&extra))
            .await
            .unwrap();

        let record = read_record(&storage, &refresh_token).await;
        assert_eq!(record["extra_data"], extra);
        assert_eq!(record["login_id"], "user_123");
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        let (storage, refresh_mgr) = setup();
        let refresh_token = refresh_mgr.generate("user_123");
        let old_access_token = "old_access_token";

        refresh_mgr
            .store(&refresh_token, old_access_token, "user_123")
            .await
            .unwrap();

        let (new_access_token, login_id) =
            refresh_mgr.refresh_access_token(&refresh_token).await.unwrap();

        assert_eq!(login_id, "user_123");
        assert_ne!(new_access_token.as_str(), old_access_token);

        // The stored record must now point at the newly issued access token.
        let record = read_record(&storage, &refresh_token).await;
        assert_eq!(record["access_token"].as_str(), Some(new_access_token.as_str()));
        assert!(record["refreshed_at"].is_string());
    }

    #[tokio::test]
    async fn test_validate_expired_refresh_token() {
        let (storage, refresh_mgr) = setup();
        let refresh_token = refresh_mgr.generate("user_123");

        // Seed a record whose expiry is already in the past.
        let record = serde_json::json!({
            "access_token": "access",
            "login_id": "user_123",
            "created_at": (Utc::now() - Duration::seconds(120)).to_rfc3339(),
            "expire_time": (Utc::now() - Duration::seconds(60)).to_rfc3339(),
        });
        storage
            .set(&record_key(&refresh_token), &record.to_string(), None)
            .await
            .unwrap();

        let first = refresh_mgr.validate(&refresh_token).await;
        assert!(matches!(first, Err(SaTokenError::TokenExpired)));

        // Validation deletes the expired record as a side effect.
        let second = refresh_mgr.validate(&refresh_token).await;
        assert!(matches!(second, Err(SaTokenError::RefreshTokenNotFound)));
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let (_, refresh_mgr) = setup();
        let refresh_token = refresh_mgr.generate("user_123");
        refresh_mgr
            .store(&refresh_token, "access", "user_123")
            .await
            .unwrap();

        refresh_mgr.delete(&refresh_token).await.unwrap();

        assert!(refresh_mgr.validate(&refresh_token).await.is_err());
    }
}
331