1use dashmap::{mapref::one::RefMut, DashMap};
2use parking_lot::Mutex;
3use serde::{Deserialize, Serialize};
4use std::{
5 collections::BTreeMap,
6 fmt::Display,
7 net::{IpAddr, SocketAddr},
8 sync::Arc,
9};
10use tokio::{
11 sync::Notify,
12 time::{Duration, Instant},
13};
14use tokio_util::sync::CancellationToken;
15use tracing::{debug, warn};
16
17use crate::{ConfigManager, Error, Result};
18
/// Identifies the subject of a ban and serves as the lookup key of the ban table.
///
/// `From` conversions are provided for [`IpAddr`] and [`SocketAddr`], so those
/// types can be passed directly to any API that takes `Into<Key>`.
#[derive(Hash, PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
pub enum Key {
    /// An IP address without a port.
    IP(IpAddr),
    /// A socket address (IP address plus port).
    Socket(SocketAddr),
    /// An account, identified by a string.
    // NOTE(review): whether the string is a name or an id is not visible here — confirm with callers.
    Account(String),
    /// A worker, identified by a string.
    // NOTE(review): the exact format of the worker string is not visible here — confirm with callers.
    Worker(String),
}
28
29impl Display for Key {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 Key::IP(ip) => write!(f, "IP: {ip}"),
33 Key::Socket(socket) => write!(f, "Socket: {socket}"),
34 Key::Account(account) => write!(f, "Account: {account}"),
35 Key::Worker(worker) => write!(f, "Worker: {worker}"),
37 }
38 }
39}
40
41impl From<SocketAddr> for Key {
42 fn from(value: SocketAddr) -> Self {
43 Key::Socket(value)
44 }
45}
46
47impl From<IpAddr> for Key {
48 fn from(value: IpAddr) -> Self {
49 Key::IP(value)
50 }
51}
52
/// A single record in the temporary-ban table.
#[derive(Debug)]
struct Entry {
    /// Identifier handed out from `State::next_id`. Together with `expires_at`
    /// it forms the key used to find this record's slot in `State::expirations`.
    id: u64,

    /// Ban details that are handed back to callers.
    data: BanInfo,

    /// Expiry instant stored with the record; combined with `id` when the
    /// record's expiration slot is removed.
    expires_at: Instant,
}
68
/// Caller-visible description of a ban.
#[derive(Serialize, Clone, Debug)]
pub struct BanInfo {
    /// The key the ban applies to.
    pub address: Key,
    /// Accumulated ban score; banning an already-banned key adds to it.
    pub score: u64,
}
74
/// Handle to the temporary-ban table.
///
/// Cloning a handle clones the inner `Arc`, so all clones operate on the same
/// ban table.
#[derive(Clone)]
pub struct BanManager {
    /// State shared with other handles and with the background purge task.
    pub(crate) shared: Arc<Shared>,
    /// Configuration source; queried for the default ban duration.
    config: ConfigManager,
}
81
/// State shared between `BanManager` handles and the background purge task.
pub(crate) struct Shared {
    /// Expiration index and id counter, guarded by a mutex.
    pub(crate) state: Mutex<State>,
    /// Token the purge task checks to decide when to stop.
    pub(crate) cancel_token: CancellationToken,
    /// Wakes the purge task, e.g. when an earlier expiration is scheduled or
    /// a handle is dropped.
    background_task: Notify,
    /// Currently banned keys and their records.
    temp_bans: Arc<DashMap<Key, Entry>>,
}
88
89impl Shared {
90 fn purge_expired_keys(&self) -> Option<Instant> {
93 if self.cancel_token.is_cancelled() {
94 return None;
97 }
98
99 let mut state = self.state.lock();
100
101 let now = Instant::now();
110
111 while let Some((&(when, id), key)) = state.expirations.iter().next() {
112 if when > now {
113 return Some(when);
116 }
117
118 self.temp_bans.remove(key);
120 state.expirations.remove(&(when, id));
121 }
122
123 None
124 }
125}
126
127async fn purge_expired_tasks(shared: Arc<Shared>) {
132 while !shared.cancel_token.is_cancelled() {
134 if let Some(when) = shared.purge_expired_keys() {
138 tokio::select! {
143 () = tokio::time::sleep_until(when) => {}
144 () = shared.background_task.notified() => {}
145 }
146 } else {
147 shared.background_task.notified().await;
150 }
151 }
152
153 debug!("Purge background task shut down");
155}
156
/// Mutex-protected bookkeeping for ban expiry.
#[derive(Default)]
pub(crate) struct State {
    /// Pending expirations keyed by `(expiry instant, entry id)`, mapping to
    /// the banned key. The ordered map keeps the earliest expiration first;
    /// the id keeps keys unique when two bans share an instant.
    expirations: BTreeMap<(Instant, u64), Key>,

    /// Next id to assign to a ban record.
    next_id: u64,
}
176
177impl State {
178 fn next_expiration(&self) -> Option<Instant> {
179 self.expirations
180 .keys()
181 .next()
182 .map(|expiration| expiration.0)
183 }
184}
185
186impl BanManager {
193 pub fn new(config: ConfigManager, cancel_token: CancellationToken) -> Self {
194 let shared = Arc::new(Shared {
195 state: Mutex::new(State::default()),
196 temp_bans: Arc::new(DashMap::new()),
197 background_task: Notify::new(),
198 cancel_token,
199 });
200
201 tokio::spawn(purge_expired_tasks(shared.clone()));
202 BanManager { shared, config }
204 }
205
206 pub fn check_banned<T: Into<Key>>(&self, key: T) -> Result<()> {
207 let key = key.into();
208 if self.shared.temp_bans.contains_key(&key) {
209 Err(Error::ConnectionBanned(key))
210 } else {
211 Ok(())
212 }
213 }
214
215 pub fn add_ban<T: Into<Key>>(&self, key: T) {
220 self.add_ban_raw(&key.into(), 10, self.config.default_ban_duration());
222 }
223
224 fn add_ban_raw(&self, key: &Key, score: u64, dur: Duration) {
227 let mut state = self.shared.state.lock();
235
236 let id = state.next_id;
239 state.next_id += 1;
240
241 let expires_at = Instant::now() + dur;
242
243 let notify = state
247 .next_expiration()
248 .map_or(true, |expiration| expiration > expires_at);
249
250 state.expirations.insert((expires_at, id), key.clone());
252
253 drop(state);
256
257 if let Some(entry) = self.shared.temp_bans.get_mut(key) {
258 let mut state = self.shared.state.lock();
261 state.expirations.remove(&(entry.expires_at, entry.id));
263
264 let new_score = entry.data.score + score;
265
266 let mut new_entry = RefMut::map(entry, |t| t);
267
268 new_entry.data.score = new_score;
270 new_entry.id = id;
271
272 drop(state);
273 } else {
274 let entry = Entry {
275 id,
276 data: BanInfo {
277 address: key.clone(),
278 score,
279 },
280 expires_at,
281 };
282
283 self.shared.temp_bans.insert(key.clone(), entry);
284 }
285
286 if notify {
287 self.shared.background_task.notify_one();
288 }
289 }
290
291 pub fn remove_ban<T: Into<Key>>(&self, key: T) -> Option<BanInfo> {
292 let mut state = self.shared.state.lock();
293 let key = key.into();
294
295 if let Some((_, entry)) = self.shared.temp_bans.remove(&key) {
296 warn!("Manually unbanning: {key}. Make sure you know what you are doing!");
297 state.expirations.remove(&(entry.expires_at, entry.id));
298 return Some(entry.data);
299 }
300
301 None
302 }
303
304 pub fn temp_bans(&self) -> Vec<BanInfo> {
307 self.shared
308 .temp_bans
309 .iter()
310 .map(|ref_multi| ref_multi.value().data.clone())
311 .collect()
312 }
313}
314
315impl Drop for BanManager {
316 fn drop(&mut self) {
317 self.shared.cancel_token.cancel();
318 self.shared.background_task.notify_one();
319 }
320}
321
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::Config;

    use super::*;
    use tokio_test::{assert_err, assert_ok};

    /// Shorthand for a millisecond duration.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    /// Builds a manager whose default ban duration is `ban_duration`.
    /// Must be called from inside a Tokio runtime because the manager spawns
    /// its purge task on construction.
    fn manager_with_ban_duration(
        ban_duration: Duration,
        cancel_token: CancellationToken,
    ) -> BanManager {
        let mut config = Config::default();
        config.bans.default_ban_duration = ban_duration;
        BanManager::new(ConfigManager::new(config), cancel_token)
    }

    // A ban with a very short duration disappears once it has expired.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn single_ban_expires() -> anyhow::Result<()> {
        let ban_manager = manager_with_ban_duration(ms(1), CancellationToken::new());

        let bad_miner: SocketAddr = assert_ok!("163.244.101.203:3841".parse());

        ban_manager.add_ban(bad_miner);
        assert_eq!(ban_manager.temp_bans().len(), 1);

        tokio::time::sleep(ms(10)).await;
        assert_eq!(ban_manager.temp_bans().len(), 0);

        Ok(())
    }

    // Banning the same key twice keeps one record and accumulates the score;
    // a manual unban then empties the table.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn ban_extended() -> anyhow::Result<()> {
        let ban_manager =
            manager_with_ban_duration(Duration::from_secs(100), CancellationToken::new());

        let bad_miner: SocketAddr = assert_ok!("163.244.101.203:3841".parse());

        ban_manager.add_ban(bad_miner);
        let after_first = ban_manager.temp_bans();
        assert_eq!(after_first.len(), 1);
        assert_eq!(after_first[0].score, 10);

        ban_manager.add_ban(bad_miner);
        let after_second = ban_manager.temp_bans();
        assert_eq!(after_second.len(), 1);
        assert_eq!(after_second[0].score, 20);

        tokio::time::sleep(ms(40)).await;

        ban_manager.remove_ban(bad_miner);
        assert_eq!(ban_manager.temp_bans().len(), 0);

        Ok(())
    }

    // Once the token is cancelled, existing bans are no longer purged even
    // after their duration has elapsed.
    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn graceful_shutdown() -> anyhow::Result<()> {
        let cancel_token = CancellationToken::new();
        let ban_manager = manager_with_ban_duration(ms(100), cancel_token.child_token());

        let addr = assert_ok!(SocketAddr::from_str("163.244.101.203:3821"));

        ban_manager.add_ban(addr);
        assert_err!(ban_manager.check_banned(addr));

        cancel_token.cancel();
        tokio::time::sleep(ms(200)).await;

        assert_err!(ban_manager.check_banned(addr));

        Ok(())
    }
}