rustydht_lib/storage/throttler.rs

1use std::net::{IpAddr, Ipv4Addr};
2use std::time::{Duration, Instant};
3
4use log::debug;
5
6type PacketCount = usize;
7
8#[derive(Clone, Copy)]
9struct ThrottlerRecord {
10    ip: IpAddr,
11    packets: PacketCount,
12    expiration: Instant,
13    creation_time: Instant,
14}
15
16impl Default for ThrottlerRecord {
17    fn default() -> Self {
18        let now = Instant::now();
19        ThrottlerRecord {
20            ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
21            packets: 0,
22            expiration: now,
23            creation_time: now,
24        }
25    }
26}
27
28impl ThrottlerRecord {
29    fn clear(&mut self) {
30        let now = Instant::now();
31        self.ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
32        self.packets = 0;
33        self.expiration = now;
34        self.creation_time = now;
35    }
36}
37
38pub struct Throttler<const NUM_RECORDS: usize> {
39    records: [ThrottlerRecord; NUM_RECORDS],
40
41    rate_limit: PacketCount,
42    period: Duration,
43    naughty_timeout: Duration,
44    max_tracking: Duration,
45}
46
47impl<const NUM_RECORDS: usize> Throttler<NUM_RECORDS> {
48    pub fn new(
49        rate_limit: PacketCount,
50        period: Duration,
51        naughty_timeout: Duration,
52        max_tracking: Duration,
53    ) -> Throttler<NUM_RECORDS> {
54        Throttler {
55            records: [ThrottlerRecord::default(); NUM_RECORDS],
56            rate_limit,
57            period,
58            naughty_timeout,
59            max_tracking,
60        }
61    }
62
63    /// Returns true if the provided IP is throttled.
64    pub fn check_throttle(
65        &mut self,
66        ip: IpAddr,
67        now: Option<Instant>,
68        count: Option<PacketCount>,
69    ) -> bool {
70        let now = match now {
71            Some(instant) => instant,
72            None => Instant::now(),
73        };
74
75        let mut found: Option<&mut ThrottlerRecord> = None;
76        let mut lamest: Option<&mut ThrottlerRecord> = None;
77        for record in &mut self.records {
78            // If record exists for this IP, use it
79            if record.ip == ip {
80                found = Some(record);
81                break;
82            }
83
84            // Keep track of the saddest/lamest record as we go
85            if let Some(lame) = &lamest {
86                if record.packets < lame.packets
87                    || (record.packets == lame.packets && record.expiration < lame.expiration)
88                {
89                    lamest = Some(record);
90                }
91            } else {
92                lamest = Some(record)
93            }
94        }
95
96        if let Some(found) = found {
97            // If this record has been around for longer than the max tracking time, don't use it and reset to blank
98            if let Some(since_creation) = now.checked_duration_since(found.creation_time) {
99                if since_creation > self.max_tracking {
100                    found.clear();
101                    return false;
102                }
103            }
104
105            if now < found.expiration {
106                found.packets = found
107                    .packets
108                    .checked_add(count.unwrap_or(1))
109                    .unwrap_or(PacketCount::MAX);
110            } else {
111                found.packets = count.unwrap_or(1);
112                found.expiration = now + self.period;
113            }
114            if found.packets > self.rate_limit {
115                debug!(target: "rustydht_lib::Throttler", "{} is throttled for {:?}. {} packets on record", ip, self.naughty_timeout, found.packets);
116                found.expiration = now + self.naughty_timeout;
117                return true;
118            }
119        } else if let Some(lamest) = lamest {
120            lamest.packets = count.unwrap_or(1);
121            lamest.expiration = now + self.period;
122            lamest.ip = ip;
123            lamest.creation_time = now;
124
125            if lamest.packets > self.rate_limit {
126                debug!(target: "rustydht_lib::Throttler", "{} is throttled for {:?}. {} packets on record", ip, self.naughty_timeout, lamest.packets);
127                lamest.expiration = now + self.naughty_timeout;
128                return true;
129            }
130        } else {
131            panic!("This should never happen ;)");
132        }
133
134        false
135    }
136
137    pub fn get_num_records(&self) -> usize {
138        NUM_RECORDS
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    /// 32-slot throttler: one packet per 5 s window, records forgotten after
    /// 10 s; only the naughty timeout varies between tests.
    fn throttler_with_timeout(naughty_secs: u64) -> Throttler<32> {
        Throttler::<32>::new(
            1,
            Duration::from_secs(5),
            Duration::from_secs(naughty_secs),
            Duration::from_secs(10),
        )
    }

    /// The address 192.168.50.`last_octet`.
    fn lan_ip(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 50, last_octet))
    }

    /// An instant `secs` seconds after the current time.
    fn secs_from_now(secs: u64) -> Instant {
        Instant::now() + Duration::from_secs(secs)
    }

    /// A single IP gets throttled on its second packet and is released once the
    /// naughty timeout has passed.
    #[test]
    fn test_one_two_punch() {
        let mut throttler = throttler_with_timeout(1);
        let peer = lan_ip(1);

        assert!(!throttler.check_throttle(peer, None, None));
        assert!(throttler.check_throttle(peer, None, None));
        assert!(!throttler.check_throttle(peer, Some(secs_from_now(2)), None));
    }

    /// Throttling still engages and expires after every slot has been filled.
    #[test]
    fn test_lots_of_ips() {
        let mut throttler = throttler_with_timeout(1);
        let capacity = throttler.get_num_records();

        for octet in 0..=capacity {
            let peer = lan_ip(u8::try_from(octet).unwrap());
            assert!(!throttler.check_throttle(peer, None, None));
        }

        let newest = lan_ip(u8::try_from(capacity).unwrap());
        assert!(throttler.check_throttle(newest, None, None));
        assert!(!throttler.check_throttle(newest, Some(secs_from_now(2)), None));
    }

    /// A record is reset once the max tracking period has elapsed, even while
    /// the IP is still inside its naughty timeout.
    #[test]
    fn test_max_tracking() {
        let mut throttler = throttler_with_timeout(100);
        let peer = lan_ip(1);

        assert!(!throttler.check_throttle(peer, None, None));
        assert!(throttler.check_throttle(peer, None, None));
        assert!(throttler.check_throttle(peer, Some(secs_from_now(9)), None));
        assert!(!throttler.check_throttle(peer, Some(secs_from_now(11)), None));
    }

    /// Huge packet counts saturate instead of overflowing.
    #[test]
    fn test_avoids_overflow() {
        let mut throttler = throttler_with_timeout(100);
        let peer = lan_ip(1);

        assert!(throttler.check_throttle(peer, None, Some(PacketCount::MAX)));
        // Still throttled, and adding another packet must not panic.
        assert!(throttler.check_throttle(peer, None, None));
    }
}