sa_token_core/
manager.rs

1// Author: 金书记
2//
3//! Token 管理器 - sa-token 的核心入口
4
5use std::sync::Arc;
6use std::collections::HashMap;
7use chrono::{Duration, Utc};
8use tokio::sync::RwLock;
9use sa_token_adapter::storage::SaStorage;
10use crate::config::SaTokenConfig;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::session::SaSession;
14use crate::event::{SaTokenEventBus, SaTokenEvent};
15use crate::online::OnlineManager;
16use crate::distributed::DistributedSessionManager;
17
18/// sa-token 管理器
19#[derive(Clone)]
20pub struct SaTokenManager {
21    pub(crate) storage: Arc<dyn SaStorage>,
22    pub config: SaTokenConfig,
23    /// 用户权限映射 user_id -> permissions
24    pub(crate) user_permissions: Arc<RwLock<HashMap<String, Vec<String>>>>,
25    /// 用户角色映射 user_id -> roles
26    pub(crate) user_roles: Arc<RwLock<HashMap<String, Vec<String>>>>,
27    /// 事件总线
28    pub(crate) event_bus: SaTokenEventBus,
29    /// 在线用户管理器
30    online_manager: Option<Arc<OnlineManager>>,
31    /// 分布式 Session 管理器
32    distributed_manager: Option<Arc<DistributedSessionManager>>,
33}
34
35impl SaTokenManager {
36    /// 创建新的管理器实例
37    pub fn new(storage: Arc<dyn SaStorage>, config: SaTokenConfig) -> Self {
38        Self { 
39            storage, 
40            config,
41            user_permissions: Arc::new(RwLock::new(HashMap::new())),
42            user_roles: Arc::new(RwLock::new(HashMap::new())),
43            event_bus: SaTokenEventBus::new(),
44            online_manager: None,
45            distributed_manager: None,
46        }
47    }
48    
49    pub fn with_online_manager(mut self, manager: Arc<OnlineManager>) -> Self {
50        self.online_manager = Some(manager);
51        self
52    }
53    
54    pub fn with_distributed_manager(mut self, manager: Arc<DistributedSessionManager>) -> Self {
55        self.distributed_manager = Some(manager);
56        self
57    }
58    
59    pub fn online_manager(&self) -> Option<&Arc<OnlineManager>> {
60        self.online_manager.as_ref()
61    }
62    
63    pub fn distributed_manager(&self) -> Option<&Arc<DistributedSessionManager>> {
64        self.distributed_manager.as_ref()
65    }
66    
67    /// 获取事件总线的引用
68    pub fn event_bus(&self) -> &SaTokenEventBus {
69        &self.event_bus
70    }
71    
72    /// 登录:为指定账号创建 token
73    pub async fn login(&self, login_id: impl Into<String>) -> SaTokenResult<TokenValue> {
74        let login_id = login_id.into();
75        
76        // 生成 token(支持 JWT)
77        let token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
78        
79        // 创建 token 信息
80        let mut token_info = TokenInfo::new(token.clone(), login_id.clone());
81        token_info.login_type = "default".to_string();
82        
83        // 设置过期时间
84        if let Some(timeout) = self.config.timeout_duration() {
85            token_info.expire_time = Some(Utc::now() + Duration::from_std(timeout).unwrap());
86        }
87        
88        // 存储 token 信息
89        let key = format!("sa:token:{}", token.as_str());
90        let value = serde_json::to_string(&token_info)
91            .map_err(|e| SaTokenError::SerializationError(e))?;
92        
93        self.storage.set(&key, &value, self.config.timeout_duration()).await
94            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
95        
96        // 保存 login_id 到 token 的映射(用于根据 login_id 查找 token)
97        let login_token_key = format!("sa:login:token:{}", login_id);
98        self.storage.set(&login_token_key, token.as_str(), self.config.timeout_duration()).await
99            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
100        
101        // 如果不允许并发登录,踢掉之前的 token
102        if !self.config.is_concurrent {
103            self.logout_by_login_id(&login_id).await?;
104        }
105        
106        // 触发登录事件
107        let event = SaTokenEvent::login(login_id.clone(), token.as_str())
108            .with_login_type(&token_info.login_type);
109        self.event_bus.publish(event).await;
110        
111        Ok(token)
112    }
113    
114    /// 登出:删除指定 token
115    pub async fn logout(&self, token: &TokenValue) -> SaTokenResult<()> {
116        // 先从存储获取 token 信息,用于触发事件(不调用 get_token_info 避免递归)
117        let key = format!("sa:token:{}", token.as_str());
118        let token_info_str = self.storage.get(&key).await
119            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
120        
121        let token_info = if let Some(value) = token_info_str {
122            serde_json::from_str::<TokenInfo>(&value).ok()
123        } else {
124            None
125        };
126        
127        // 删除 token
128        self.storage.delete(&key).await
129            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
130        
131        // 触发登出事件
132        if let Some(info) = token_info.clone() {
133            let event = SaTokenEvent::logout(&info.login_id, token.as_str())
134                .with_login_type(&info.login_type);
135            self.event_bus.publish(event).await;
136            
137            // 如果有在线用户管理,通知用户下线
138            if let Some(online_mgr) = &self.online_manager {
139                online_mgr.mark_offline(&info.login_id, token.as_str()).await;
140            }
141        }
142        
143        Ok(())
144    }
145    
146    /// 根据登录 ID 登出所有 token
147    pub async fn logout_by_login_id(&self, login_id: &str) -> SaTokenResult<()> {
148        // 获取所有 token 键的前缀
149        let token_prefix = "sa:token:";
150        
151        // 获取所有 token 键
152        if let Ok(keys) = self.storage.keys(&format!("{}*", token_prefix)).await {
153            // 遍历所有 token 键
154            for key in keys {
155                // 获取 token 值
156                if let Ok(Some(token_info_str)) = self.storage.get(&key).await {
157                    // 反序列化 token 信息
158                    if let Ok(token_info) = serde_json::from_str::<TokenInfo>(&token_info_str) {
159                        // 如果 login_id 匹配,则登出该 token
160                        if token_info.login_id == login_id {
161                            // 提取 token 字符串(从键中移除前缀)
162                            let token_str = key[token_prefix.len()..].to_string();
163                            let token = TokenValue::new(token_str);
164                            
165                            // 调用登出方法(logout 方法内部会处理删除映射和在线用户管理)
166                            let _ = self.logout(&token).await;
167                        }
168                    }
169                }
170            }
171        }
172        
173        Ok(())
174    }
175    
176    /// 获取 token 信息
177    pub async fn get_token_info(&self, token: &TokenValue) -> SaTokenResult<TokenInfo> {
178        let key = format!("sa:token:{}", token.as_str());
179        let value = self.storage.get(&key).await
180            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
181            .ok_or(SaTokenError::TokenNotFound)?;
182        
183        let token_info: TokenInfo = serde_json::from_str(&value)
184            .map_err(|e| SaTokenError::SerializationError(e))?;
185        
186        // 检查是否过期
187        if token_info.is_expired() {
188            // 删除过期的 token
189            self.logout(token).await?;
190            return Err(SaTokenError::TokenExpired);
191        }
192        
193        // 如果开启了自动续签,则自动续签
194        // 注意:为了避免递归调用 get_token_info,这里直接更新过期时间
195        if self.config.auto_renew {
196            let renew_timeout = if self.config.active_timeout > 0 {
197                self.config.active_timeout
198            } else {
199                self.config.timeout
200            };
201            
202            // 直接续签(不递归调用 get_token_info)
203            let _ = self.renew_timeout_internal(token, renew_timeout, &token_info).await;
204        }
205        
206        Ok(token_info)
207    }
208    
209    /// 检查 token 是否有效
210    pub async fn is_valid(&self, token: &TokenValue) -> bool {
211        self.get_token_info(token).await.is_ok()
212    }
213    
214    /// 获取 session
215    pub async fn get_session(&self, login_id: &str) -> SaTokenResult<SaSession> {
216        let key = format!("sa:session:{}", login_id);
217        let value = self.storage.get(&key).await
218            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
219        
220        if let Some(value) = value {
221            let session: SaSession = serde_json::from_str(&value)
222                .map_err(|e| SaTokenError::SerializationError(e))?;
223            Ok(session)
224        } else {
225            Ok(SaSession::new(login_id))
226        }
227    }
228    
229    /// 保存 session
230    pub async fn save_session(&self, session: &SaSession) -> SaTokenResult<()> {
231        let key = format!("sa:session:{}", session.id);
232        let value = serde_json::to_string(session)
233            .map_err(|e| SaTokenError::SerializationError(e))?;
234        
235        self.storage.set(&key, &value, None).await
236            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
237        
238        Ok(())
239    }
240    
241    /// 删除 session
242    pub async fn delete_session(&self, login_id: &str) -> SaTokenResult<()> {
243        let key = format!("sa:session:{}", login_id);
244        self.storage.delete(&key).await
245            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
246        Ok(())
247    }
248    
249    /// 续期 token(重置过期时间)
250    pub async fn renew_timeout(
251        &self,
252        token: &TokenValue,
253        timeout_seconds: i64,
254    ) -> SaTokenResult<()> {
255        let token_info = self.get_token_info(token).await?;
256        self.renew_timeout_internal(token, timeout_seconds, &token_info).await
257    }
258    
259    /// 内部续期方法(避免递归调用 get_token_info)
260    async fn renew_timeout_internal(
261        &self,
262        token: &TokenValue,
263        timeout_seconds: i64,
264        token_info: &TokenInfo,
265    ) -> SaTokenResult<()> {
266        let mut new_token_info = token_info.clone();
267        
268        // 设置新的过期时间
269        use chrono::{Utc, Duration};
270        let new_expire_time = Utc::now() + Duration::seconds(timeout_seconds);
271        new_token_info.expire_time = Some(new_expire_time);
272        
273        // 保存更新后的 token 信息
274        let key = format!("sa:token:{}", token.as_str());
275        let value = serde_json::to_string(&new_token_info)
276            .map_err(|e| SaTokenError::SerializationError(e))?;
277        
278        let timeout = std::time::Duration::from_secs(timeout_seconds as u64);
279        self.storage.set(&key, &value, Some(timeout)).await
280            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
281        
282        Ok(())
283    }
284    
285    /// 踢人下线
286    pub async fn kick_out(&self, login_id: &str) -> SaTokenResult<()> {
287        let token_result = self.storage.get(&format!("sa:login:token:{}", login_id)).await;
288        
289        if let Some(online_mgr) = &self.online_manager {
290            let _ = online_mgr.kick_out_notify(login_id, "Account kicked out".to_string()).await;
291        }
292        
293        self.logout_by_login_id(login_id).await?;
294        self.delete_session(login_id).await?;
295        
296        if let Ok(Some(token_str)) = token_result {
297            let event = SaTokenEvent::kick_out(login_id, token_str);
298            self.event_bus.publish(event).await;
299        }
300        
301        Ok(())
302    }
303}