sbd_server/
ip_rate.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tokio::task::JoinHandle;
4
5type Map = HashMap<Arc<std::net::Ipv6Addr>, u64>;
6
7/// Rate limit connections by IP address.
8pub struct IpRate {
9    origin: tokio::time::Instant,
10    map: Arc<Mutex<Map>>,
11    disabled: bool,
12    limit: u64,
13    burst: u64,
14    ip_deny: crate::ip_deny::IpDeny,
15}
16
17impl IpRate {
18    /// Construct a new IpRate limit instance.
19    pub fn new(config: Arc<crate::Config>) -> Self {
20        Self {
21            origin: tokio::time::Instant::now(),
22            map: Arc::new(Mutex::new(HashMap::new())),
23            disabled: config.disable_rate_limiting,
24            limit: config.limit_ip_byte_nanos() as u64,
25            burst: config.limit_ip_byte_burst as u64
26                * config.limit_ip_byte_nanos() as u64,
27            ip_deny: crate::ip_deny::IpDeny::new(config),
28        }
29    }
30
31    /// Prune entries that have tracked backwards 10s or more.
32    /// The 10s just prevents hashtable thrashing if a connection
33    /// is using significantly less than its rate limit.
34    /// This is why the keepalive interval is 5 seconds and
35    /// connections are closed after 10 seconds.
36    pub fn prune(&self) {
37        let now = self.origin.elapsed().as_nanos() as u64;
38        self.map.lock().unwrap().retain(|_, cur| {
39            if now <= *cur {
40                true
41            } else {
42                // examples using seconds:
43                // now:100,cur:120 100-120=-20<10  true=keep
44                // now:100,cur:100 100-100=0<10    true=keep
45                // now:100,cur:80   100-80=20<10  false=prune
46                now - *cur < 10_000_000_000
47            }
48        });
49    }
50
51    /// Return true if this ip is blocked.
52    pub async fn is_blocked(&self, ip: &Arc<std::net::Ipv6Addr>) -> bool {
53        self.ip_deny.is_blocked(ip).await
54    }
55
56    /// Return true if we are not over the rate limit.
57    pub async fn is_ok(
58        &self,
59        ip: &Arc<std::net::Ipv6Addr>,
60        bytes: usize,
61    ) -> bool {
62        if self.disabled {
63            return true;
64        }
65
66        // multiply by our rate allowed per byte
67        let rate_add = bytes as u64 * self.limit;
68
69        // get now
70        let now = self.origin.elapsed().as_nanos() as u64;
71
72        let is_ok = {
73            // lock the map mutex
74            let mut lock = self.map.lock().unwrap();
75
76            // get the entry (default to now)
77            let e = lock.entry(ip.clone()).or_insert(now);
78
79            // if we've already used time greater than now use that,
80            // otherwise consider we're starting from scratch
81            let cur = std::cmp::max(*e, now) + rate_add;
82
83            // update the map with the current limit
84            *e = cur;
85
86            // subtract now back out to see if we're greater than our burst
87            cur - now <= self.burst
88        };
89
90        if !is_ok {
91            self.ip_deny.block(ip).await;
92        }
93
94        is_ok
95    }
96}
97
98/// Spawn a Tokio task to prune the IpRate map.
99pub fn spawn_prune_task(ip_rate: Arc<IpRate>) -> JoinHandle<()> {
100    let ip_rate = Arc::downgrade(&ip_rate);
101    tokio::task::spawn(async move {
102        loop {
103            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
104            if let Some(ip_rate) = ip_rate.upgrade() {
105                ip_rate.prune();
106            } else {
107                break;
108            }
109        }
110    })
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build an `IpRate` with an explicit per-byte cost and burst allowance,
    /// bypassing `Config` so the arithmetic in each test is easy to follow.
    fn test_new(limit: u64, burst: u64) -> IpRate {
        IpRate {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(HashMap::new())),
            disabled: false,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(Arc::new(
                crate::Config::default(),
            )),
        }
    }

    /// The single address every test charges against.
    fn test_addr() -> Arc<std::net::Ipv6Addr> {
        Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1))
    }

    /// Number of addresses currently tracked by `rate`.
    fn tracked(rate: &IpRate) -> usize {
        rate.map.lock().unwrap().len()
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_one_to_one() {
        let ip = test_addr();
        let rate = test_new(1, 1);

        // Spending one byte per elapsed nanosecond never trips the limit.
        for _ in 0..10 {
            tokio::time::advance(Duration::from_nanos(1)).await;
            assert!(rate.is_ok(&ip, 1).await);
        }

        // Spending again with no time passing exceeds the burst.
        assert!(!rate.is_ok(&ip, 1).await);

        // The entry is still ahead of the clock, so it survives pruning.
        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(tracked(&rate), 1);

        // One nanosecond short of the prune age: still kept.
        tokio::time::advance(Duration::from_secs(10)).await;
        rate.prune();
        assert_eq!(tracked(&rate), 1);

        // Exactly at the prune age: removed.
        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(tracked(&rate), 0);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_burst() {
        let ip = test_addr();
        let rate = test_new(1, 5);

        // The whole burst can be spent at once...
        for _ in 0..5 {
            assert!(rate.is_ok(&ip, 1).await);
        }
        // ...but not a byte more.
        assert!(!rate.is_ok(&ip, 1).await);

        // A little time frees up room again.
        tokio::time::advance(Duration::from_nanos(2)).await;
        assert!(rate.is_ok(&ip, 1).await);

        // Just short of the prune age: kept.
        tokio::time::advance(Duration::from_secs(10)).await;
        tokio::time::advance(Duration::from_nanos(4)).await;
        rate.prune();
        assert_eq!(tracked(&rate), 1);

        // At the prune age: removed.
        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(tracked(&rate), 0);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_limit_mult() {
        let ip = test_addr();
        let rate = test_new(3, 13);

        // Each 2-byte charge costs 6ns; a 13ns burst fits two of them.
        let expected_round = [true, true, false];

        for expected in expected_round {
            assert_eq!(rate.is_ok(&ip, 2).await, expected);
        }

        tokio::time::advance(Duration::from_secs(10)).await;

        for expected in expected_round {
            assert_eq!(rate.is_ok(&ip, 2).await, expected);
        }
    }
}