sa_token_core/
manager.rs

1// Author: 金书记
2//
3//! Token 管理器 - sa-token 的核心入口
4
5use std::sync::Arc;
6use std::collections::HashMap;
7use chrono::{Duration, Utc};
8use tokio::sync::RwLock;
9use sa_token_adapter::storage::SaStorage;
10use crate::config::SaTokenConfig;
11use crate::error::{SaTokenError, SaTokenResult};
12use crate::token::{TokenInfo, TokenValue, TokenGenerator};
13use crate::session::SaSession;
14
15/// sa-token 管理器
16#[derive(Clone)]
17pub struct SaTokenManager {
18    pub(crate) storage: Arc<dyn SaStorage>,
19    pub config: SaTokenConfig,
20    /// 用户权限映射 user_id -> permissions
21    pub(crate) user_permissions: Arc<RwLock<HashMap<String, Vec<String>>>>,
22    /// 用户角色映射 user_id -> roles
23    pub(crate) user_roles: Arc<RwLock<HashMap<String, Vec<String>>>>,
24}
25
26impl SaTokenManager {
27    /// 创建新的管理器实例
28    pub fn new(storage: Arc<dyn SaStorage>, config: SaTokenConfig) -> Self {
29        Self { 
30            storage, 
31            config,
32            user_permissions: Arc::new(RwLock::new(HashMap::new())),
33            user_roles: Arc::new(RwLock::new(HashMap::new())),
34        }
35    }
36    
37    /// 登录:为指定账号创建 token
38    pub async fn login(&self, login_id: impl Into<String>) -> SaTokenResult<TokenValue> {
39        let login_id = login_id.into();
40        
41        // 生成 token
42        let token = TokenGenerator::generate(&self.config);
43        
44        // 创建 token 信息
45        let mut token_info = TokenInfo::new(token.clone(), login_id.clone());
46        token_info.login_type = "default".to_string();
47        
48        // 设置过期时间
49        if let Some(timeout) = self.config.timeout_duration() {
50            token_info.expire_time = Some(Utc::now() + Duration::from_std(timeout).unwrap());
51        }
52        
53        // 存储 token 信息
54        let key = format!("sa:token:{}", token.as_str());
55        let value = serde_json::to_string(&token_info)
56            .map_err(|e| SaTokenError::SerializationError(e))?;
57        
58        self.storage.set(&key, &value, self.config.timeout_duration()).await
59            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
60        
61        // 保存 login_id 到 token 的映射(用于根据 login_id 查找 token)
62        let login_token_key = format!("sa:login:token:{}", login_id);
63        self.storage.set(&login_token_key, token.as_str(), self.config.timeout_duration()).await
64            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
65        
66        // 如果不允许并发登录,踢掉之前的 token
67        if !self.config.is_concurrent {
68            self.logout_by_login_id(&login_id).await?;
69        }
70        
71        Ok(token)
72    }
73    
74    /// 登出:删除指定 token
75    pub async fn logout(&self, token: &TokenValue) -> SaTokenResult<()> {
76        let key = format!("sa:token:{}", token.as_str());
77        self.storage.delete(&key).await
78            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
79        Ok(())
80    }
81    
82    /// 根据登录 ID 登出所有 token
83    pub async fn logout_by_login_id(&self, _login_id: &str) -> SaTokenResult<()> {
84        // TODO: 实现根据登录 ID 查找所有 token 并删除
85        // 这需要维护 login_id -> tokens 的映射
86        Ok(())
87    }
88    
89    /// 获取 token 信息
90    pub async fn get_token_info(&self, token: &TokenValue) -> SaTokenResult<TokenInfo> {
91        let key = format!("sa:token:{}", token.as_str());
92        let value = self.storage.get(&key).await
93            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
94            .ok_or(SaTokenError::TokenNotFound)?;
95        
96        let token_info: TokenInfo = serde_json::from_str(&value)
97            .map_err(|e| SaTokenError::SerializationError(e))?;
98        
99        // 检查是否过期
100        if token_info.is_expired() {
101            // 删除过期的 token
102            self.logout(token).await?;
103            return Err(SaTokenError::TokenExpired);
104        }
105        
106        // 如果开启了自动续签,则自动续签
107        // 注意:为了避免递归调用 get_token_info,这里直接更新过期时间
108        if self.config.auto_renew {
109            let renew_timeout = if self.config.active_timeout > 0 {
110                self.config.active_timeout
111            } else {
112                self.config.timeout
113            };
114            
115            // 直接续签(不递归调用 get_token_info)
116            let _ = self.renew_timeout_internal(token, renew_timeout, &token_info).await;
117        }
118        
119        Ok(token_info)
120    }
121    
122    /// 检查 token 是否有效
123    pub async fn is_valid(&self, token: &TokenValue) -> bool {
124        self.get_token_info(token).await.is_ok()
125    }
126    
127    /// 获取 session
128    pub async fn get_session(&self, login_id: &str) -> SaTokenResult<SaSession> {
129        let key = format!("sa:session:{}", login_id);
130        let value = self.storage.get(&key).await
131            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
132        
133        if let Some(value) = value {
134            let session: SaSession = serde_json::from_str(&value)
135                .map_err(|e| SaTokenError::SerializationError(e))?;
136            Ok(session)
137        } else {
138            Ok(SaSession::new(login_id))
139        }
140    }
141    
142    /// 保存 session
143    pub async fn save_session(&self, session: &SaSession) -> SaTokenResult<()> {
144        let key = format!("sa:session:{}", session.id);
145        let value = serde_json::to_string(session)
146            .map_err(|e| SaTokenError::SerializationError(e))?;
147        
148        self.storage.set(&key, &value, None).await
149            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
150        
151        Ok(())
152    }
153    
154    /// 删除 session
155    pub async fn delete_session(&self, login_id: &str) -> SaTokenResult<()> {
156        let key = format!("sa:session:{}", login_id);
157        self.storage.delete(&key).await
158            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
159        Ok(())
160    }
161    
162    /// 续期 token(重置过期时间)
163    pub async fn renew_timeout(
164        &self,
165        token: &TokenValue,
166        timeout_seconds: i64,
167    ) -> SaTokenResult<()> {
168        let token_info = self.get_token_info(token).await?;
169        self.renew_timeout_internal(token, timeout_seconds, &token_info).await
170    }
171    
172    /// 内部续期方法(避免递归调用 get_token_info)
173    async fn renew_timeout_internal(
174        &self,
175        token: &TokenValue,
176        timeout_seconds: i64,
177        token_info: &TokenInfo,
178    ) -> SaTokenResult<()> {
179        let mut new_token_info = token_info.clone();
180        
181        // 设置新的过期时间
182        use chrono::{Utc, Duration};
183        let new_expire_time = Utc::now() + Duration::seconds(timeout_seconds);
184        new_token_info.expire_time = Some(new_expire_time);
185        
186        // 保存更新后的 token 信息
187        let key = format!("sa:token:{}", token.as_str());
188        let value = serde_json::to_string(&new_token_info)
189            .map_err(|e| SaTokenError::SerializationError(e))?;
190        
191        let timeout = std::time::Duration::from_secs(timeout_seconds as u64);
192        self.storage.set(&key, &value, Some(timeout)).await
193            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
194        
195        Ok(())
196    }
197    
198    /// 踢人下线
199    pub async fn kick_out(&self, login_id: &str) -> SaTokenResult<()> {
200        self.logout_by_login_id(login_id).await?;
201        self.delete_session(login_id).await?;
202        Ok(())
203    }
204}