stratum_server/
ban_manager.rs

1use dashmap::{mapref::one::RefMut, DashMap};
2use parking_lot::Mutex;
3use serde::{Deserialize, Serialize};
4use std::{
5    collections::BTreeMap,
6    fmt::Display,
7    net::{IpAddr, SocketAddr},
8    sync::Arc,
9};
10use tokio::{
11    sync::Notify,
12    time::{Duration, Instant},
13};
14use tokio_util::sync::CancellationToken;
15use tracing::{debug, warn};
16
17use crate::{ConfigManager, Error, Result};
18
19#[derive(Hash, PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
20pub enum Key {
21    IP(IpAddr),
22    Socket(SocketAddr),
23    // Account(Username)
24    Account(String),
25    // Account(Username)
26    Worker(String),
27}
28
29impl Display for Key {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            Key::IP(ip) => write!(f, "IP: {ip}"),
33            Key::Socket(socket) => write!(f, "Socket: {socket}"),
34            Key::Account(account) => write!(f, "Account: {account}"),
35            //@todo do this better when its account, workername
36            Key::Worker(worker) => write!(f, "Worker: {worker}"),
37        }
38    }
39}
40
41impl From<SocketAddr> for Key {
42    fn from(value: SocketAddr) -> Self {
43        Key::Socket(value)
44    }
45}
46
47impl From<IpAddr> for Key {
48    fn from(value: IpAddr) -> Self {
49        Key::IP(value)
50    }
51}
52
53//@todo implement froms here.
54
55/// A wrapping around entries that adds the link to the entry's expiration, via a `delay_queue` key.
56#[derive(Debug)]
57struct Entry {
58    /// Uniquely identifies this entry.
59    id: u64,
60
61    /// Stored data
62    data: BanInfo,
63
64    /// Instant at which the entry expires and should be removed from the
65    /// database.
66    expires_at: Instant,
67}
68
69#[derive(Serialize, Clone, Debug)]
70pub struct BanInfo {
71    pub address: Key,
72    pub score: u64,
73}
74
75//@todo perma bans
76#[derive(Clone)]
77pub struct BanManager {
78    pub(crate) shared: Arc<Shared>,
79    config: ConfigManager,
80}
81
82pub(crate) struct Shared {
83    pub(crate) state: Mutex<State>,
84    pub(crate) cancel_token: CancellationToken,
85    background_task: Notify,
86    temp_bans: Arc<DashMap<Key, Entry>>,
87}
88
89impl Shared {
90    /// Purge all expired keys and return the `Instant` at which the **next**
91    /// key will expire. The background task will sleep until this instant.
92    fn purge_expired_keys(&self) -> Option<Instant> {
93        if self.cancel_token.is_cancelled() {
94            // The database is shutting down. All handles to the shared state
95            // have dropped. The background task should exit.
96            return None;
97        }
98
99        let mut state = self.state.lock();
100
101        // This is needed to make the borrow checker happy. In short, `lock()`
102        // returns a `MutexGuard` and not a `&mut State`. The borrow checker is
103        // not able to see "through" the mutex guard and determine that it is
104        // safe to access both `state.expirations` and `state.entries` mutably,
105        // so we get a "real" mutable reference to `State` outside of the loop.
106        // let state = &mut *state;
107
108        // Find all keys scheduled to expire **before** now.
109        let now = Instant::now();
110
111        while let Some((&(when, id), key)) = state.expirations.iter().next() {
112            if when > now {
113                // Done purging, `when` is the instant at which the next key
114                // expires. The worker task will wait until this instant.
115                return Some(when);
116            }
117
118            // The key expired, remove it
119            self.temp_bans.remove(key);
120            state.expirations.remove(&(when, id));
121        }
122
123        None
124    }
125}
126
127/// Routine executed by the background task.
128///
129/// Wait to be notified. On notification, purge any expired keys from the shared
130/// state handle. If `shutdown` is set, terminate the task.
131async fn purge_expired_tasks(shared: Arc<Shared>) {
132    // If the shutdown flag is set, then the task should exit.
133    while !shared.cancel_token.is_cancelled() {
134        // Purge all keys that are expired. The function returns the instant at
135        // which the **next** key will expire. The worker should wait until the
136        // instant has passed then purge again.
137        if let Some(when) = shared.purge_expired_keys() {
138            // Wait until the next key expires **or** until the background task
139            // is notified. If the task is notified, then it must reload its
140            // state as new keys have been set to expire early. This is done by
141            // looping.
142            tokio::select! {
143                () = tokio::time::sleep_until(when) => {}
144                () = shared.background_task.notified() => {}
145            }
146        } else {
147            // There are no keys expiring in the future. Wait until the task is
148            // notified.
149            shared.background_task.notified().await;
150        }
151    }
152
153    //@todo figure out why this triggers immediately
154    debug!("Purge background task shut down");
155}
156
157#[derive(Default)]
158pub(crate) struct State {
159    // pub(crate) expirations: DelayQueue<SocketAddr>,
160    /// Tracks key TTLs.
161    ///
162    /// A `BTreeMap` is used to maintain expirations sorted by when they expire.
163    /// This allows the background task to iterate this map to find the value
164    /// expiring next.
165    ///
166    /// While highly unlikely, it is possible for more than one expiration to be
167    /// created for the same instant. Because of this, the `Instant` is
168    /// insufficient for the key. A unique expiration identifier (`u64`) is used
169    /// to break these ties.
170    expirations: BTreeMap<(Instant, u64), Key>,
171
172    /// Identifier to use for the next expiration. Each expiration is associated
173    /// with a unique identifier. See above for why.
174    next_id: u64,
175}
176
177impl State {
178    fn next_expiration(&self) -> Option<Instant> {
179        self.expirations
180            .keys()
181            .next()
182            .map(|expiration| expiration.0)
183    }
184}
185
186//@todo 1. Add remove_ban, and feature gate it for the API -> That way we can unban miners manually
187//if we need to (or IP addresses).
188//2. Feature gate all of Ban Manager,
189//3. Allow for white_list or non-bannables
190//4. Allow for not just IPs to be banned, but usernames, and usnermae/workernames combinations.
191//- For the above can we switch to using an Enum as the hashmpa key
192impl BanManager {
193    pub fn new(config: ConfigManager, cancel_token: CancellationToken) -> Self {
194        let shared = Arc::new(Shared {
195            state: Mutex::new(State::default()),
196            temp_bans: Arc::new(DashMap::new()),
197            background_task: Notify::new(),
198            cancel_token,
199        });
200
201        tokio::spawn(purge_expired_tasks(shared.clone()));
202        // 1 hour
203        BanManager { shared, config }
204    }
205
206    pub fn check_banned<T: Into<Key>>(&self, key: T) -> Result<()> {
207        let key = key.into();
208        if self.shared.temp_bans.contains_key(&key) {
209            Err(Error::ConnectionBanned(key))
210        } else {
211            Ok(())
212        }
213    }
214
215    //@todo have a check function so that we A. don't ban loopback_etc, and B. don't ban
216    //whitelisted or our own IPs.
217
218    //@todo figure out what score means
219    pub fn add_ban<T: Into<Key>>(&self, key: T) {
220        // if self.config.current_config().bans.whitelisted_ips
221        self.add_ban_raw(&key.into(), 10, self.config.default_ban_duration());
222    }
223
224    //@todo add_ban generic that impls Into<Key>
225
226    fn add_ban_raw(&self, key: &Key, score: u64, dur: Duration) {
227        //@todo there may be other's we want to check for here.
228        //@todo we probably want to have an IP whitelist on here.
229        //@todo move these to ban socket or ip.
230        // if addr.ip().is_loopback() || addr.ip().is_unspecified() {
231        //     return;
232        // }
233
234        let mut state = self.shared.state.lock();
235
236        // Get and increment the next insertion ID. Guarded by the lock, this
237        // ensures a unique identifier is associated with each `set` operation.
238        let id = state.next_id;
239        state.next_id += 1;
240
241        let expires_at = Instant::now() + dur;
242
243        // Only notify the worker task if the newly inserted expiration is the
244        // **next** key to evict. In this case, the worker needs to be woken up
245        // to update its state.
246        let notify = state
247            .next_expiration()
248            .map_or(true, |expiration| expiration > expires_at);
249
250        // Track the expiration.
251        state.expirations.insert((expires_at, id), key.clone());
252
253        //@todo might make sense to drop state here, before jumping into Dashmap potential locking
254        //scenario.
255        drop(state);
256
257        if let Some(entry) = self.shared.temp_bans.get_mut(key) {
258            // let old_entry = entry;
259
260            let mut state = self.shared.state.lock();
261            //Remove the old expiration as we've set a new one.
262            state.expirations.remove(&(entry.expires_at, entry.id));
263
264            let new_score = entry.data.score + score;
265
266            let mut new_entry = RefMut::map(entry, |t| t);
267
268            //@todo test if this works.
269            new_entry.data.score = new_score;
270            new_entry.id = id;
271
272            drop(state);
273        } else {
274            let entry = Entry {
275                id,
276                data: BanInfo {
277                    address: key.clone(),
278                    score,
279                },
280                expires_at,
281            };
282
283            self.shared.temp_bans.insert(key.clone(), entry);
284        }
285
286        if notify {
287            self.shared.background_task.notify_one();
288        }
289    }
290
291    pub fn remove_ban<T: Into<Key>>(&self, key: T) -> Option<BanInfo> {
292        let mut state = self.shared.state.lock();
293        let key = key.into();
294
295        if let Some((_, entry)) = self.shared.temp_bans.remove(&key) {
296            warn!("Manually unbanning: {key}. Make sure you know what you are doing!");
297            state.expirations.remove(&(entry.expires_at, entry.id));
298            return Some(entry.data);
299        }
300
301        None
302    }
303
304    // #[cfg(feature = "api")]
305    /// Returns a vector of referencing all values in the map.
306    pub fn temp_bans(&self) -> Vec<BanInfo> {
307        self.shared
308            .temp_bans
309            .iter()
310            .map(|ref_multi| ref_multi.value().data.clone())
311            .collect()
312    }
313}
314
315impl Drop for BanManager {
316    fn drop(&mut self) {
317        self.shared.cancel_token.cancel();
318        self.shared.background_task.notify_one();
319    }
320}
321
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::Config;

    // Pull in everything from the parent module (BanManager, Key, ...).
    use super::*;
    use tokio_test::{assert_err, assert_ok};

    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn single_ban_expires() -> anyhow::Result<()> {
        let cancel_token = CancellationToken::new();
        let mut config = Config::default();
        config.bans.default_ban_duration = ms(1);
        let ban_manager = BanManager::new(ConfigManager::new(config), cancel_token);

        let bad_miner: SocketAddr = assert_ok!("163.244.101.203:3841".parse());

        ban_manager.add_ban(bad_miner);
        assert_eq!(ban_manager.temp_bans().len(), 1);

        // Give the background task time to purge the 1ms ban.
        tokio::time::sleep(ms(10)).await;
        assert_eq!(ban_manager.temp_bans().len(), 0);

        Ok(())
    }

    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn ban_extended() -> anyhow::Result<()> {
        let cancel_token = CancellationToken::new();
        let mut config = Config::default();
        config.bans.default_ban_duration = Duration::from_secs(100);
        let ban_manager = BanManager::new(ConfigManager::new(config), cancel_token);

        let bad_miner: SocketAddr = assert_ok!("163.244.101.203:3841".parse());

        ban_manager.add_ban(bad_miner);
        let temp_bans = ban_manager.temp_bans();
        assert_eq!(temp_bans.len(), 1);
        assert_eq!(temp_bans[0].address, Key::from(bad_miner));
        assert_eq!(temp_bans[0].score, 10);

        // Banning the same key again refreshes the existing entry and accumulates score.
        ban_manager.add_ban(bad_miner);
        let temp_bans = ban_manager.temp_bans();
        assert_eq!(temp_bans.len(), 1);
        assert_eq!(temp_bans[0].address, Key::from(bad_miner));
        assert_eq!(temp_bans[0].score, 20);

        let removed = ban_manager.remove_ban(bad_miner);
        assert_eq!(removed.map(|info| info.score), Some(20));
        assert_eq!(ban_manager.temp_bans().len(), 0);
        assert_ok!(ban_manager.check_banned(bad_miner));

        Ok(())
    }

    #[cfg_attr(coverage_nightly, coverage(off))]
    #[tokio::test]
    async fn graceful_shutdown() -> anyhow::Result<()> {
        let cancel_token = CancellationToken::new();
        let mut config = Config::default();
        config.bans.default_ban_duration = ms(100);
        let ban_manager = BanManager::new(ConfigManager::new(config), cancel_token.child_token());

        let addr = assert_ok!(SocketAddr::from_str("163.244.101.203:3821"));

        ban_manager.add_ban(addr);
        assert_err!(ban_manager.check_banned(addr));

        cancel_token.cancel();

        // After shutdown the purge task is gone, so the ban is never lifted.
        tokio::time::sleep(ms(200)).await;
        assert_err!(ban_manager.check_banned(addr));

        Ok(())
    }
}