1use std::sync::Arc;
6use std::collections::HashMap;
7use chrono::{Duration, Utc};
8use tokio::sync::RwLock;
9use sa_token_adapter::storage::SaStorage;
10use crate::config::SaTokenConfig;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::session::SaSession;
14use crate::event::{SaTokenEventBus, SaTokenEvent};
15use crate::online::OnlineManager;
16use crate::distributed::DistributedSessionManager;
17
/// Central manager for token, session and login-state operations.
///
/// The storage backend, the permission/role maps and the optional managers
/// are all held behind `Arc`, so clones of this struct share them.
#[derive(Clone)]
pub struct SaTokenManager {
    /// Storage backend that token, login and session records are written to.
    pub(crate) storage: Arc<dyn SaStorage>,
    /// Configuration consulted for timeouts, concurrent login and auto-renewal.
    pub config: SaTokenConfig,
    /// Shared map from a string key to a list of permission strings.
    /// NOTE(review): presumably keyed by login id — confirm against callers.
    pub(crate) user_permissions: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// Shared map from a string key to a list of role strings.
    /// NOTE(review): presumably keyed by login id — confirm against callers.
    pub(crate) user_roles: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// Bus on which login, logout and kick-out events are published.
    pub(crate) event_bus: SaTokenEventBus,
    /// Optional online manager; notified on logout and kick-out when present.
    online_manager: Option<Arc<OnlineManager>>,
    /// Optional distributed session manager; only stored and exposed here.
    distributed_manager: Option<Arc<DistributedSessionManager>>,
}
34
35impl SaTokenManager {
36 pub fn new(storage: Arc<dyn SaStorage>, config: SaTokenConfig) -> Self {
38 Self {
39 storage,
40 config,
41 user_permissions: Arc::new(RwLock::new(HashMap::new())),
42 user_roles: Arc::new(RwLock::new(HashMap::new())),
43 event_bus: SaTokenEventBus::new(),
44 online_manager: None,
45 distributed_manager: None,
46 }
47 }
48
49 pub fn with_online_manager(mut self, manager: Arc<OnlineManager>) -> Self {
50 self.online_manager = Some(manager);
51 self
52 }
53
54 pub fn with_distributed_manager(mut self, manager: Arc<DistributedSessionManager>) -> Self {
55 self.distributed_manager = Some(manager);
56 self
57 }
58
59 pub fn online_manager(&self) -> Option<&Arc<OnlineManager>> {
60 self.online_manager.as_ref()
61 }
62
63 pub fn distributed_manager(&self) -> Option<&Arc<DistributedSessionManager>> {
64 self.distributed_manager.as_ref()
65 }
66
67 pub fn event_bus(&self) -> &SaTokenEventBus {
69 &self.event_bus
70 }
71
72 pub async fn login(&self, login_id: impl Into<String>) -> SaTokenResult<TokenValue> {
74 let login_id = login_id.into();
75
76 let token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
78
79 let mut token_info = TokenInfo::new(token.clone(), login_id.clone());
81 token_info.login_type = "default".to_string();
82
83 if let Some(timeout) = self.config.timeout_duration() {
85 token_info.expire_time = Some(Utc::now() + Duration::from_std(timeout).unwrap());
86 }
87
88 let key = format!("sa:token:{}", token.as_str());
90 let value = serde_json::to_string(&token_info)
91 .map_err(|e| SaTokenError::SerializationError(e))?;
92
93 self.storage.set(&key, &value, self.config.timeout_duration()).await
94 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
95
96 let login_token_key = format!("sa:login:token:{}", login_id);
98 self.storage.set(&login_token_key, token.as_str(), self.config.timeout_duration()).await
99 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
100
101 if !self.config.is_concurrent {
103 self.logout_by_login_id(&login_id).await?;
104 }
105
106 let event = SaTokenEvent::login(login_id.clone(), token.as_str())
108 .with_login_type(&token_info.login_type);
109 self.event_bus.publish(event).await;
110
111 Ok(token)
112 }
113
114 pub async fn logout(&self, token: &TokenValue) -> SaTokenResult<()> {
116 let key = format!("sa:token:{}", token.as_str());
118 let token_info_str = self.storage.get(&key).await
119 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
120
121 let token_info = if let Some(value) = token_info_str {
122 serde_json::from_str::<TokenInfo>(&value).ok()
123 } else {
124 None
125 };
126
127 self.storage.delete(&key).await
129 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
130
131 if let Some(info) = token_info.clone() {
133 let event = SaTokenEvent::logout(&info.login_id, token.as_str())
134 .with_login_type(&info.login_type);
135 self.event_bus.publish(event).await;
136
137 if let Some(online_mgr) = &self.online_manager {
139 online_mgr.mark_offline(&info.login_id, token.as_str()).await;
140 }
141 }
142
143 Ok(())
144 }
145
146 pub async fn logout_by_login_id(&self, login_id: &str) -> SaTokenResult<()> {
148 let token_prefix = "sa:token:";
150
151 if let Ok(keys) = self.storage.keys(&format!("{}*", token_prefix)).await {
153 for key in keys {
155 if let Ok(Some(token_info_str)) = self.storage.get(&key).await {
157 if let Ok(token_info) = serde_json::from_str::<TokenInfo>(&token_info_str) {
159 if token_info.login_id == login_id {
161 let token_str = key[token_prefix.len()..].to_string();
163 let token = TokenValue::new(token_str);
164
165 let _ = self.logout(&token).await;
167 }
168 }
169 }
170 }
171 }
172
173 Ok(())
174 }
175
176 pub async fn get_token_info(&self, token: &TokenValue) -> SaTokenResult<TokenInfo> {
178 let key = format!("sa:token:{}", token.as_str());
179 let value = self.storage.get(&key).await
180 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
181 .ok_or(SaTokenError::TokenNotFound)?;
182
183 let token_info: TokenInfo = serde_json::from_str(&value)
184 .map_err(|e| SaTokenError::SerializationError(e))?;
185
186 if token_info.is_expired() {
188 self.logout(token).await?;
190 return Err(SaTokenError::TokenExpired);
191 }
192
193 if self.config.auto_renew {
196 let renew_timeout = if self.config.active_timeout > 0 {
197 self.config.active_timeout
198 } else {
199 self.config.timeout
200 };
201
202 let _ = self.renew_timeout_internal(token, renew_timeout, &token_info).await;
204 }
205
206 Ok(token_info)
207 }
208
209 pub async fn is_valid(&self, token: &TokenValue) -> bool {
211 self.get_token_info(token).await.is_ok()
212 }
213
214 pub async fn get_session(&self, login_id: &str) -> SaTokenResult<SaSession> {
216 let key = format!("sa:session:{}", login_id);
217 let value = self.storage.get(&key).await
218 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
219
220 if let Some(value) = value {
221 let session: SaSession = serde_json::from_str(&value)
222 .map_err(|e| SaTokenError::SerializationError(e))?;
223 Ok(session)
224 } else {
225 Ok(SaSession::new(login_id))
226 }
227 }
228
229 pub async fn save_session(&self, session: &SaSession) -> SaTokenResult<()> {
231 let key = format!("sa:session:{}", session.id);
232 let value = serde_json::to_string(session)
233 .map_err(|e| SaTokenError::SerializationError(e))?;
234
235 self.storage.set(&key, &value, None).await
236 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
237
238 Ok(())
239 }
240
241 pub async fn delete_session(&self, login_id: &str) -> SaTokenResult<()> {
243 let key = format!("sa:session:{}", login_id);
244 self.storage.delete(&key).await
245 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
246 Ok(())
247 }
248
249 pub async fn renew_timeout(
251 &self,
252 token: &TokenValue,
253 timeout_seconds: i64,
254 ) -> SaTokenResult<()> {
255 let token_info = self.get_token_info(token).await?;
256 self.renew_timeout_internal(token, timeout_seconds, &token_info).await
257 }
258
259 async fn renew_timeout_internal(
261 &self,
262 token: &TokenValue,
263 timeout_seconds: i64,
264 token_info: &TokenInfo,
265 ) -> SaTokenResult<()> {
266 let mut new_token_info = token_info.clone();
267
268 use chrono::{Utc, Duration};
270 let new_expire_time = Utc::now() + Duration::seconds(timeout_seconds);
271 new_token_info.expire_time = Some(new_expire_time);
272
273 let key = format!("sa:token:{}", token.as_str());
275 let value = serde_json::to_string(&new_token_info)
276 .map_err(|e| SaTokenError::SerializationError(e))?;
277
278 let timeout = std::time::Duration::from_secs(timeout_seconds as u64);
279 self.storage.set(&key, &value, Some(timeout)).await
280 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
281
282 Ok(())
283 }
284
285 pub async fn kick_out(&self, login_id: &str) -> SaTokenResult<()> {
287 let token_result = self.storage.get(&format!("sa:login:token:{}", login_id)).await;
288
289 if let Some(online_mgr) = &self.online_manager {
290 let _ = online_mgr.kick_out_notify(login_id, "Account kicked out".to_string()).await;
291 }
292
293 self.logout_by_login_id(login_id).await?;
294 self.delete_session(login_id).await?;
295
296 if let Ok(Some(token_str)) = token_result {
297 let event = SaTokenEvent::kick_out(login_id, token_str);
298 self.event_bus.publish(event).await;
299 }
300
301 Ok(())
302 }
303}