Skip to main content

sa_token_core/
disable.rs

1// Author: 金书记
2//
3//! 账号/服务封禁(对齐 Java StpLogic.disable/checkDisable)
4
5use std::time::Duration;
6
7use crate::error::{SaTokenError, SaTokenResult};
8use crate::manager::SaTokenManager;
9
/// Service name used when a ban is issued without an explicit service
/// (mirrors Java `DEFAULT_DISABLE_SERVICE`).
pub const DEFAULT_DISABLE_SERVICE: &str = "login";

/// Smallest level accepted when writing a ban (mirrors Java `MIN_DISABLE_LEVEL`).
pub const MIN_DISABLE_LEVEL: i32 = 1;

/// Level written by the convenience `disable` call (mirrors Java `DEFAULT_DISABLE_LEVEL`).
pub const DEFAULT_DISABLE_LEVEL: i32 = 1;

/// Sentinel returned by level lookups when no ban is in effect
/// (mirrors Java `NOT_DISABLE_LEVEL`).
pub const NOT_DISABLE_LEVEL: i32 = -2;
21
22impl SaTokenManager {
23    fn disable_key(&self, login_id: &str, service: &str) -> String {
24        self.config.make_key("disable:", &format!("{}:{}", login_id, service))
25    }
26
27    /// 封禁指定账号的指定服务及等级
28    ///
29    /// `time` 单位为秒,`-1` 表示永久封禁。
30    pub async fn disable_level(
31        &self,
32        login_id: &str,
33        service: &str,
34        level: i32,
35        time: i64,
36    ) -> SaTokenResult<()> {
37        if login_id.trim().is_empty() {
38            return Err(SaTokenError::ConfigError(
39                "login_id is required for disable".to_string(),
40            ));
41        }
42        if service.trim().is_empty() {
43            return Err(SaTokenError::ConfigError(
44                "service is required for disable".to_string(),
45            ));
46        }
47        if level < MIN_DISABLE_LEVEL && level != 0 {
48            return Err(SaTokenError::ConfigError(format!(
49                "disable level must be >= {} (0 allowed)",
50                MIN_DISABLE_LEVEL
51            )));
52        }
53
54        let ttl = if time < 0 {
55            None
56        } else {
57            Some(Duration::from_secs(time as u64))
58        };
59
60        self.storage
61            .set(
62                &self.disable_key(login_id, service),
63                &level.to_string(),
64                ttl,
65            )
66            .await
67            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
68
69        let event = crate::event::SaTokenEvent::banned(login_id);
70        self.event_bus.publish(event).await;
71
72        Ok(())
73    }
74
75    /// 封禁指定账号(默认服务 `login`、默认等级)
76    pub async fn disable(&self, login_id: &str, time: i64) -> SaTokenResult<()> {
77        self.disable_level(
78            login_id,
79            DEFAULT_DISABLE_SERVICE,
80            DEFAULT_DISABLE_LEVEL,
81            time,
82        )
83        .await
84    }
85
86    /// 获取封禁等级;未封禁返回 [`NOT_DISABLE_LEVEL`]
87    pub async fn get_disable_level(&self, login_id: &str, service: &str) -> SaTokenResult<i32> {
88        let key = self.disable_key(login_id, service);
89        let value = self
90            .storage
91            .get(&key)
92            .await
93            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
94
95        if let Some(v) = value {
96            return v.parse::<i32>().map_err(|_| {
97                SaTokenError::StorageError(format!("invalid disable level for key {}", key))
98            });
99        }
100
101        if let Some(iface) = &self.stp_interface {
102            if let Some(level) = iface.is_disabled(login_id, service).await? {
103                return Ok(level);
104            }
105        }
106
107        Ok(NOT_DISABLE_LEVEL)
108    }
109
110    /// 是否已被封禁到指定等级(含更高等级)
111    pub async fn is_disable_level(
112        &self,
113        login_id: &str,
114        service: &str,
115        level: i32,
116    ) -> SaTokenResult<bool> {
117        let disable_level = self.get_disable_level(login_id, service).await?;
118        if disable_level == NOT_DISABLE_LEVEL {
119            return Ok(false);
120        }
121        Ok(disable_level >= level)
122    }
123
124    /// 校验封禁;若等级达到阈值则抛出 [`SaTokenError::DisableService`]
125    pub async fn check_disable_level(
126        &self,
127        login_id: &str,
128        service: &str,
129        level: i32,
130    ) -> SaTokenResult<()> {
131        let disable_level = self.get_disable_level(login_id, service).await?;
132        if disable_level == NOT_DISABLE_LEVEL {
133            return Ok(());
134        }
135        if disable_level >= level {
136            return Err(SaTokenError::AccountBanned(format!(
137                "service={} level={}",
138                service, disable_level
139            )));
140        }
141        Ok(())
142    }
143
144    /// 校验多个服务的封禁(全部通过才算通过)
145    pub async fn check_disable_services(
146        &self,
147        login_id: &str,
148        services: &[&str],
149        level: i32,
150    ) -> SaTokenResult<()> {
151        for service in services {
152            self.check_disable_level(login_id, service, level).await?;
153        }
154        Ok(())
155    }
156
157    /// 解封指定服务
158    pub async fn untie_disable(&self, login_id: &str, service: &str) -> SaTokenResult<()> {
159        self.storage
160            .delete(&self.disable_key(login_id, service))
161            .await
162            .map_err(|e| SaTokenError::StorageError(e.to_string()))
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SaTokenConfig;
    use sa_token_storage_memory::MemoryStorage;
    use std::sync::Arc;

    /// Creates a manager backed by a fresh in-memory store and the default config.
    fn manager() -> SaTokenManager {
        SaTokenManager::new(Arc::new(MemoryStorage::new()), SaTokenConfig::default())
    }

    #[tokio::test]
    async fn disable_and_check_level() {
        let mgr = manager();
        let user = "u1";
        let service = "login";

        // Ban at level 2 for 60 seconds.
        mgr.disable_level(user, service, 2, 60).await.unwrap();

        // A level-2 ban satisfies thresholds 1 and 2, but not 3.
        for &(threshold, expected) in &[(1, true), (2, true), (3, false)] {
            let banned = mgr
                .is_disable_level(user, service, threshold)
                .await
                .unwrap();
            assert_eq!(banned, expected, "threshold {}", threshold);
        }
        assert!(mgr.check_disable_level(user, service, 2).await.is_err());

        // Lifting the ban restores the "not disabled" sentinel.
        mgr.untie_disable(user, service).await.unwrap();
        let level = mgr.get_disable_level(user, service).await.unwrap();
        assert_eq!(level, NOT_DISABLE_LEVEL);
    }
}