1use std::sync::Arc;
9use chrono::{DateTime, Utc, Duration};
10use sa_token_adapter::storage::SaStorage;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::TokenValue;
13use crate::token::TokenGenerator;
14use crate::config::SaTokenConfig;
15use uuid::Uuid;
16
17#[derive(Clone)]
22pub struct RefreshTokenManager {
23 storage: Arc<dyn SaStorage>,
24 config: Arc<SaTokenConfig>,
25}
26
27impl RefreshTokenManager {
28 pub fn new(storage: Arc<dyn SaStorage>, config: Arc<SaTokenConfig>) -> Self {
35 Self { storage, config }
36 }
37
38 pub fn generate(&self, login_id: &str) -> String {
48 format!(
50 "refresh_{}_{}_{}",
51 Utc::now().timestamp_millis(),
52 login_id,
53 Uuid::new_v4().simple()
54 )
55 }
56
57 pub async fn store(
65 &self,
66 refresh_token: &str,
67 access_token: &str,
68 login_id: &str,
69 ) -> SaTokenResult<()> {
70 let key = format!("sa:refresh:{}", refresh_token);
71 let expire_time = if self.config.refresh_token_timeout > 0 {
72 Some(Utc::now() + Duration::seconds(self.config.refresh_token_timeout))
73 } else {
74 None
75 };
76
77 let value = serde_json::json!({
78 "access_token": access_token,
79 "login_id": login_id,
80 "created_at": Utc::now().to_rfc3339(),
81 "expire_time": expire_time.map(|t| t.to_rfc3339()),
82 }).to_string();
83
84 let ttl = if self.config.refresh_token_timeout > 0 {
85 Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
86 } else {
87 None
88 };
89
90 self.storage.set(&key, &value, ttl)
91 .await
92 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
93
94 Ok(())
95 }
96
97 pub async fn validate(&self, refresh_token: &str) -> SaTokenResult<String> {
107 let key = format!("sa:refresh:{}", refresh_token);
108
109 let value_str = self.storage.get(&key)
110 .await
111 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
112 .ok_or_else(|| SaTokenError::RefreshTokenNotFound)?;
113
114 let value: serde_json::Value = serde_json::from_str(&value_str)
115 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
116
117 let login_id = value["login_id"].as_str()
118 .ok_or_else(|| SaTokenError::RefreshTokenMissingLoginId)?
119 .to_string();
120
121 if let Some(expire_str) = value["expire_time"].as_str() {
123 let expire_time = DateTime::parse_from_rfc3339(expire_str)
124 .map_err(|_| SaTokenError::RefreshTokenInvalidExpireTime)?
125 .with_timezone(&Utc);
126
127 if Utc::now() > expire_time {
128 self.delete(refresh_token).await?;
130 return Err(SaTokenError::TokenExpired);
131 }
132 }
133
134 Ok(login_id)
135 }
136
137 pub async fn refresh_access_token(
147 &self,
148 refresh_token: &str,
149 ) -> SaTokenResult<(TokenValue, String)> {
150 let login_id = self.validate(refresh_token).await?;
152
153 let new_access_token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
155
156 let key = format!("sa:refresh:{}", refresh_token);
158 let value_str = self.storage.get(&key)
159 .await
160 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
161 .ok_or_else(|| SaTokenError::RefreshTokenNotFound)?;
162
163 let mut value: serde_json::Value = serde_json::from_str(&value_str)
164 .map_err(|_| SaTokenError::RefreshTokenInvalidData)?;
165
166 value["access_token"] = serde_json::json!(new_access_token.as_str());
167 value["refreshed_at"] = serde_json::json!(Utc::now().to_rfc3339());
168
169 let ttl = if self.config.refresh_token_timeout > 0 {
170 Some(std::time::Duration::from_secs(self.config.refresh_token_timeout as u64))
171 } else {
172 None
173 };
174
175 self.storage.set(&key, &value.to_string(), ttl)
176 .await
177 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
178
179 Ok((new_access_token, login_id))
180 }
181
182 pub async fn delete(&self, refresh_token: &str) -> SaTokenResult<()> {
188 let key = format!("sa:refresh:{}", refresh_token);
189 self.storage.delete(&key)
190 .await
191 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
192 Ok(())
193 }
194
195 pub async fn get_user_refresh_tokens(&self, _login_id: &str) -> SaTokenResult<Vec<String>> {
200 Ok(vec![])
205 }
206
207 pub async fn revoke_all_for_user(&self, login_id: &str) -> SaTokenResult<()> {
213 let tokens = self.get_user_refresh_tokens(login_id).await?;
214 for token in tokens {
215 self.delete(&token).await?;
216 }
217 Ok(())
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TokenStyle;
    use sa_token_storage_memory::MemoryStorage;

    /// Configuration shared by every test: UUID-style tokens, a one-hour access
    /// timeout and a two-hour refresh timeout.
    fn create_test_config() -> Arc<SaTokenConfig> {
        let config = SaTokenConfig {
            token_style: TokenStyle::Uuid,
            timeout: 3600,
            refresh_token_timeout: 7200,
            enable_refresh_token: true,
            ..Default::default()
        };
        Arc::new(config)
    }

    /// Builds a manager backed by a fresh, empty in-memory store.
    fn new_manager() -> RefreshTokenManager {
        RefreshTokenManager::new(Arc::new(MemoryStorage::new()), create_test_config())
    }

    #[tokio::test]
    async fn test_refresh_token_generation() {
        let mgr = new_manager();

        let (first, second) = (mgr.generate("user_123"), mgr.generate("user_123"));

        assert_ne!(first, second);
        assert!(first.starts_with("refresh_"));
    }

    #[tokio::test]
    async fn test_refresh_token_store_and_validate() {
        let mgr = new_manager();
        let token = mgr.generate("user_123");

        mgr.store(&token, "access_token_123", "user_123")
            .await
            .unwrap();

        let owner = mgr.validate(&token).await.unwrap();
        assert_eq!(owner, "user_123");
    }

    #[tokio::test]
    async fn test_refresh_access_token() {
        const OLD_ACCESS: &str = "old_access_token";

        let mgr = new_manager();
        let token = mgr.generate("user_123");
        mgr.store(&token, OLD_ACCESS, "user_123").await.unwrap();

        let (fresh_access, owner) = mgr.refresh_access_token(&token).await.unwrap();

        assert_eq!(owner, "user_123");
        assert_ne!(fresh_access.as_str(), OLD_ACCESS);
    }

    #[tokio::test]
    async fn test_delete_refresh_token() {
        let mgr = new_manager();
        let token = mgr.generate("user_123");
        mgr.store(&token, "access", "user_123").await.unwrap();

        mgr.delete(&token).await.unwrap();

        assert!(mgr.validate(&token).await.is_err());
    }
}
303