1use std::time::Duration;
6
7use crate::error::{SaTokenError, SaTokenResult};
8use crate::manager::SaTokenManager;
9
/// Service name that [`SaTokenManager::disable`] bans when no explicit service is given.
pub const DEFAULT_DISABLE_SERVICE: &str = "login";

/// Lowest positive level accepted by `disable_level` (level `0` is accepted as well).
pub const MIN_DISABLE_LEVEL: i32 = 1;

/// Level applied by [`SaTokenManager::disable`]; the same value as [`MIN_DISABLE_LEVEL`].
pub const DEFAULT_DISABLE_LEVEL: i32 = MIN_DISABLE_LEVEL;

/// Value returned by `get_disable_level` when no disable level is recorded for an account.
pub const NOT_DISABLE_LEVEL: i32 = -2;
21
22impl SaTokenManager {
23 fn disable_key(&self, login_id: &str, service: &str) -> String {
24 self.config.make_key("disable:", &format!("{}:{}", login_id, service))
25 }
26
27 pub async fn disable_level(
31 &self,
32 login_id: &str,
33 service: &str,
34 level: i32,
35 time: i64,
36 ) -> SaTokenResult<()> {
37 if login_id.trim().is_empty() {
38 return Err(SaTokenError::ConfigError(
39 "login_id is required for disable".to_string(),
40 ));
41 }
42 if service.trim().is_empty() {
43 return Err(SaTokenError::ConfigError(
44 "service is required for disable".to_string(),
45 ));
46 }
47 if level < MIN_DISABLE_LEVEL && level != 0 {
48 return Err(SaTokenError::ConfigError(format!(
49 "disable level must be >= {} (0 allowed)",
50 MIN_DISABLE_LEVEL
51 )));
52 }
53
54 let ttl = if time < 0 {
55 None
56 } else {
57 Some(Duration::from_secs(time as u64))
58 };
59
60 self.storage
61 .set(
62 &self.disable_key(login_id, service),
63 &level.to_string(),
64 ttl,
65 )
66 .await
67 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
68
69 let event = crate::event::SaTokenEvent::banned(login_id);
70 self.event_bus.publish(event).await;
71
72 Ok(())
73 }
74
75 pub async fn disable(&self, login_id: &str, time: i64) -> SaTokenResult<()> {
77 self.disable_level(
78 login_id,
79 DEFAULT_DISABLE_SERVICE,
80 DEFAULT_DISABLE_LEVEL,
81 time,
82 )
83 .await
84 }
85
86 pub async fn get_disable_level(&self, login_id: &str, service: &str) -> SaTokenResult<i32> {
88 let key = self.disable_key(login_id, service);
89 let value = self
90 .storage
91 .get(&key)
92 .await
93 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
94
95 if let Some(v) = value {
96 return v.parse::<i32>().map_err(|_| {
97 SaTokenError::StorageError(format!("invalid disable level for key {}", key))
98 });
99 }
100
101 if let Some(iface) = &self.stp_interface {
102 if let Some(level) = iface.is_disabled(login_id, service).await? {
103 return Ok(level);
104 }
105 }
106
107 Ok(NOT_DISABLE_LEVEL)
108 }
109
110 pub async fn is_disable_level(
112 &self,
113 login_id: &str,
114 service: &str,
115 level: i32,
116 ) -> SaTokenResult<bool> {
117 let disable_level = self.get_disable_level(login_id, service).await?;
118 if disable_level == NOT_DISABLE_LEVEL {
119 return Ok(false);
120 }
121 Ok(disable_level >= level)
122 }
123
124 pub async fn check_disable_level(
126 &self,
127 login_id: &str,
128 service: &str,
129 level: i32,
130 ) -> SaTokenResult<()> {
131 let disable_level = self.get_disable_level(login_id, service).await?;
132 if disable_level == NOT_DISABLE_LEVEL {
133 return Ok(());
134 }
135 if disable_level >= level {
136 return Err(SaTokenError::AccountBanned(format!(
137 "service={} level={}",
138 service, disable_level
139 )));
140 }
141 Ok(())
142 }
143
144 pub async fn check_disable_services(
146 &self,
147 login_id: &str,
148 services: &[&str],
149 level: i32,
150 ) -> SaTokenResult<()> {
151 for service in services {
152 self.check_disable_level(login_id, service, level).await?;
153 }
154 Ok(())
155 }
156
157 pub async fn untie_disable(&self, login_id: &str, service: &str) -> SaTokenResult<()> {
159 self.storage
160 .delete(&self.disable_key(login_id, service))
161 .await
162 .map_err(|e| SaTokenError::StorageError(e.to_string()))
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SaTokenConfig;
    use sa_token_storage_memory::MemoryStorage;
    use std::sync::Arc;

    /// Fresh manager backed by in-memory storage with the default config.
    fn manager() -> SaTokenManager {
        SaTokenManager::new(
            Arc::new(MemoryStorage::new()),
            SaTokenConfig::default(),
        )
    }

    #[tokio::test]
    async fn disable_and_check_level() {
        let mgr = manager();
        mgr.disable_level("u1", "login", 2, 60).await.unwrap();
        assert!(mgr.is_disable_level("u1", "login", 1).await.unwrap());
        assert!(mgr.is_disable_level("u1", "login", 2).await.unwrap());
        assert!(!mgr.is_disable_level("u1", "login", 3).await.unwrap());
        assert!(matches!(
            mgr.check_disable_level("u1", "login", 2).await,
            Err(SaTokenError::AccountBanned(_))
        ));
        mgr.untie_disable("u1", "login").await.unwrap();
        assert_eq!(
            mgr.get_disable_level("u1", "login").await.unwrap(),
            NOT_DISABLE_LEVEL
        );
    }

    #[tokio::test]
    async fn unknown_account_is_not_disabled() {
        let mgr = manager();
        assert_eq!(
            mgr.get_disable_level("nobody", "login").await.unwrap(),
            NOT_DISABLE_LEVEL
        );
        assert!(!mgr.is_disable_level("nobody", "login", 1).await.unwrap());
        assert!(mgr.check_disable_level("nobody", "login", 1).await.is_ok());
    }

    #[tokio::test]
    async fn disable_uses_default_service_and_level() {
        let mgr = manager();
        mgr.disable("u2", 60).await.unwrap();
        assert_eq!(
            mgr.get_disable_level("u2", DEFAULT_DISABLE_SERVICE)
                .await
                .unwrap(),
            DEFAULT_DISABLE_LEVEL
        );
        assert!(mgr
            .check_disable_level("u2", DEFAULT_DISABLE_SERVICE, DEFAULT_DISABLE_LEVEL)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn disable_level_validates_input() {
        let mgr = manager();
        assert!(mgr.disable_level("", "login", 1, 60).await.is_err());
        assert!(mgr.disable_level("   ", "login", 1, 60).await.is_err());
        assert!(mgr.disable_level("u3", " ", 1, 60).await.is_err());
        assert!(mgr.disable_level("u3", "login", -1, 60).await.is_err());

        // Level 0 is explicitly allowed, but it does not satisfy a level-1 check.
        mgr.disable_level("u3", "login", 0, 60).await.unwrap();
        assert_eq!(mgr.get_disable_level("u3", "login").await.unwrap(), 0);
        assert!(!mgr.is_disable_level("u3", "login", 1).await.unwrap());
    }

    #[tokio::test]
    async fn negative_time_stores_level_without_expiry() {
        let mgr = manager();
        mgr.disable_level("u4", "login", 1, -1).await.unwrap();
        assert_eq!(mgr.get_disable_level("u4", "login").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn check_disable_services_fails_if_any_service_banned() {
        let mgr = manager();
        mgr.disable_level("u5", "comment", 3, 60).await.unwrap();
        assert!(mgr
            .check_disable_services("u5", &["login", "comment"], 1)
            .await
            .is_err());
        assert!(mgr
            .check_disable_services("u5", &["login", "pay"], 1)
            .await
            .is_ok());
        assert!(mgr.check_disable_services("u5", &[], 1).await.is_ok());
    }
}