sa_token_core/
manager.rs

1// Author: 金书记
2//
3//! Token 管理器 - sa-token 的核心入口
4
5use std::sync::Arc;
6use std::collections::HashMap;
7use chrono::{Duration, Utc};
8use tokio::sync::RwLock;
9use sa_token_adapter::storage::SaStorage;
10use crate::config::SaTokenConfig;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::session::SaSession;
14use crate::event::{SaTokenEventBus, SaTokenEvent};
15
16/// sa-token 管理器
17#[derive(Clone)]
18pub struct SaTokenManager {
19    pub(crate) storage: Arc<dyn SaStorage>,
20    pub config: SaTokenConfig,
21    /// 用户权限映射 user_id -> permissions
22    pub(crate) user_permissions: Arc<RwLock<HashMap<String, Vec<String>>>>,
23    /// 用户角色映射 user_id -> roles
24    pub(crate) user_roles: Arc<RwLock<HashMap<String, Vec<String>>>>,
25    /// 事件总线
26    pub(crate) event_bus: SaTokenEventBus,
27}
28
29impl SaTokenManager {
30    /// 创建新的管理器实例
31    pub fn new(storage: Arc<dyn SaStorage>, config: SaTokenConfig) -> Self {
32        Self { 
33            storage, 
34            config,
35            user_permissions: Arc::new(RwLock::new(HashMap::new())),
36            user_roles: Arc::new(RwLock::new(HashMap::new())),
37            event_bus: SaTokenEventBus::new(),
38        }
39    }
40    
41    /// 获取事件总线的引用
42    pub fn event_bus(&self) -> &SaTokenEventBus {
43        &self.event_bus
44    }
45    
46    /// 登录:为指定账号创建 token
47    pub async fn login(&self, login_id: impl Into<String>) -> SaTokenResult<TokenValue> {
48        let login_id = login_id.into();
49        
50        // 生成 token(支持 JWT)
51        let token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
52        
53        // 创建 token 信息
54        let mut token_info = TokenInfo::new(token.clone(), login_id.clone());
55        token_info.login_type = "default".to_string();
56        
57        // 设置过期时间
58        if let Some(timeout) = self.config.timeout_duration() {
59            token_info.expire_time = Some(Utc::now() + Duration::from_std(timeout).unwrap());
60        }
61        
62        // 存储 token 信息
63        let key = format!("sa:token:{}", token.as_str());
64        let value = serde_json::to_string(&token_info)
65            .map_err(|e| SaTokenError::SerializationError(e))?;
66        
67        self.storage.set(&key, &value, self.config.timeout_duration()).await
68            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
69        
70        // 保存 login_id 到 token 的映射(用于根据 login_id 查找 token)
71        let login_token_key = format!("sa:login:token:{}", login_id);
72        self.storage.set(&login_token_key, token.as_str(), self.config.timeout_duration()).await
73            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
74        
75        // 如果不允许并发登录,踢掉之前的 token
76        if !self.config.is_concurrent {
77            self.logout_by_login_id(&login_id).await?;
78        }
79        
80        // 触发登录事件
81        let event = SaTokenEvent::login(login_id.clone(), token.as_str())
82            .with_login_type(&token_info.login_type);
83        self.event_bus.publish(event).await;
84        
85        Ok(token)
86    }
87    
88    /// 登出:删除指定 token
89    pub async fn logout(&self, token: &TokenValue) -> SaTokenResult<()> {
90        // 先从存储获取 token 信息,用于触发事件(不调用 get_token_info 避免递归)
91        let key = format!("sa:token:{}", token.as_str());
92        let token_info_str = self.storage.get(&key).await
93            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
94        
95        let token_info = if let Some(value) = token_info_str {
96            serde_json::from_str::<TokenInfo>(&value).ok()
97        } else {
98            None
99        };
100        
101        // 删除 token
102        self.storage.delete(&key).await
103            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
104        
105        // 触发登出事件
106        if let Some(info) = token_info {
107            let event = SaTokenEvent::logout(info.login_id, token.as_str())
108                .with_login_type(&info.login_type);
109            self.event_bus.publish(event).await;
110        }
111        
112        Ok(())
113    }
114    
115    /// 根据登录 ID 登出所有 token
116    pub async fn logout_by_login_id(&self, _login_id: &str) -> SaTokenResult<()> {
117        // TODO: 实现根据登录 ID 查找所有 token 并删除
118        // 这需要维护 login_id -> tokens 的映射
119        Ok(())
120    }
121    
122    /// 获取 token 信息
123    pub async fn get_token_info(&self, token: &TokenValue) -> SaTokenResult<TokenInfo> {
124        let key = format!("sa:token:{}", token.as_str());
125        let value = self.storage.get(&key).await
126            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
127            .ok_or(SaTokenError::TokenNotFound)?;
128        
129        let token_info: TokenInfo = serde_json::from_str(&value)
130            .map_err(|e| SaTokenError::SerializationError(e))?;
131        
132        // 检查是否过期
133        if token_info.is_expired() {
134            // 删除过期的 token
135            self.logout(token).await?;
136            return Err(SaTokenError::TokenExpired);
137        }
138        
139        // 如果开启了自动续签,则自动续签
140        // 注意:为了避免递归调用 get_token_info,这里直接更新过期时间
141        if self.config.auto_renew {
142            let renew_timeout = if self.config.active_timeout > 0 {
143                self.config.active_timeout
144            } else {
145                self.config.timeout
146            };
147            
148            // 直接续签(不递归调用 get_token_info)
149            let _ = self.renew_timeout_internal(token, renew_timeout, &token_info).await;
150        }
151        
152        Ok(token_info)
153    }
154    
155    /// 检查 token 是否有效
156    pub async fn is_valid(&self, token: &TokenValue) -> bool {
157        self.get_token_info(token).await.is_ok()
158    }
159    
160    /// 获取 session
161    pub async fn get_session(&self, login_id: &str) -> SaTokenResult<SaSession> {
162        let key = format!("sa:session:{}", login_id);
163        let value = self.storage.get(&key).await
164            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
165        
166        if let Some(value) = value {
167            let session: SaSession = serde_json::from_str(&value)
168                .map_err(|e| SaTokenError::SerializationError(e))?;
169            Ok(session)
170        } else {
171            Ok(SaSession::new(login_id))
172        }
173    }
174    
175    /// 保存 session
176    pub async fn save_session(&self, session: &SaSession) -> SaTokenResult<()> {
177        let key = format!("sa:session:{}", session.id);
178        let value = serde_json::to_string(session)
179            .map_err(|e| SaTokenError::SerializationError(e))?;
180        
181        self.storage.set(&key, &value, None).await
182            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
183        
184        Ok(())
185    }
186    
187    /// 删除 session
188    pub async fn delete_session(&self, login_id: &str) -> SaTokenResult<()> {
189        let key = format!("sa:session:{}", login_id);
190        self.storage.delete(&key).await
191            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
192        Ok(())
193    }
194    
195    /// 续期 token(重置过期时间)
196    pub async fn renew_timeout(
197        &self,
198        token: &TokenValue,
199        timeout_seconds: i64,
200    ) -> SaTokenResult<()> {
201        let token_info = self.get_token_info(token).await?;
202        self.renew_timeout_internal(token, timeout_seconds, &token_info).await
203    }
204    
205    /// 内部续期方法(避免递归调用 get_token_info)
206    async fn renew_timeout_internal(
207        &self,
208        token: &TokenValue,
209        timeout_seconds: i64,
210        token_info: &TokenInfo,
211    ) -> SaTokenResult<()> {
212        let mut new_token_info = token_info.clone();
213        
214        // 设置新的过期时间
215        use chrono::{Utc, Duration};
216        let new_expire_time = Utc::now() + Duration::seconds(timeout_seconds);
217        new_token_info.expire_time = Some(new_expire_time);
218        
219        // 保存更新后的 token 信息
220        let key = format!("sa:token:{}", token.as_str());
221        let value = serde_json::to_string(&new_token_info)
222            .map_err(|e| SaTokenError::SerializationError(e))?;
223        
224        let timeout = std::time::Duration::from_secs(timeout_seconds as u64);
225        self.storage.set(&key, &value, Some(timeout)).await
226            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
227        
228        Ok(())
229    }
230    
231    /// 踢人下线
232    pub async fn kick_out(&self, login_id: &str) -> SaTokenResult<()> {
233        // 先获取 token,用于触发事件
234        let token_result = self.storage.get(&format!("sa:login:token:{}", login_id)).await;
235        
236        self.logout_by_login_id(login_id).await?;
237        self.delete_session(login_id).await?;
238        
239        // 触发踢出下线事件
240        if let Ok(Some(token_str)) = token_result {
241            let event = SaTokenEvent::kick_out(login_id, token_str);
242            self.event_bus.publish(event).await;
243        }
244        
245        Ok(())
246    }
247}