1use std::net::{IpAddr, Ipv4Addr};
2use std::time::{Duration, Instant};
3
4use log::debug;
5
/// Number of packets attributed to an IP address.
type PacketCount = usize;

/// One per-IP bookkeeping slot of a `Throttler`.
///
/// An unused (default or cleared) slot holds the placeholder address 127.0.0.1, zero
/// packets, and both timestamps set to the moment it was reset.
#[derive(Clone, Copy)]
struct ThrottlerRecord {
    /// Address this slot is tracking (127.0.0.1 when unused).
    ip: IpAddr,
    /// Packets counted for `ip` in its current window.
    packets: PacketCount,
    /// End of the current counting window or throttle period.
    expiration: Instant,
    /// When this slot was last reset or assigned to `ip`.
    creation_time: Instant,
}

impl Default for ThrottlerRecord {
    /// An unused slot: placeholder address 127.0.0.1, no packets, timestamps set to now.
    fn default() -> Self {
        let stamp = Instant::now();
        Self {
            creation_time: stamp,
            expiration: stamp,
            packets: 0,
            ip: IpAddr::from(Ipv4Addr::LOCALHOST),
        }
    }
}

impl ThrottlerRecord {
    /// Returns this slot to the unused state, exactly as `ThrottlerRecord::default()` builds it.
    fn clear(&mut self) {
        *self = Self::default();
    }
}
37
38pub struct Throttler<const NUM_RECORDS: usize> {
39 records: [ThrottlerRecord; NUM_RECORDS],
40
41 rate_limit: PacketCount,
42 period: Duration,
43 naughty_timeout: Duration,
44 max_tracking: Duration,
45}
46
47impl<const NUM_RECORDS: usize> Throttler<NUM_RECORDS> {
48 pub fn new(
49 rate_limit: PacketCount,
50 period: Duration,
51 naughty_timeout: Duration,
52 max_tracking: Duration,
53 ) -> Throttler<NUM_RECORDS> {
54 Throttler {
55 records: [ThrottlerRecord::default(); NUM_RECORDS],
56 rate_limit,
57 period,
58 naughty_timeout,
59 max_tracking,
60 }
61 }
62
63 pub fn check_throttle(
65 &mut self,
66 ip: IpAddr,
67 now: Option<Instant>,
68 count: Option<PacketCount>,
69 ) -> bool {
70 let now = match now {
71 Some(instant) => instant,
72 None => Instant::now(),
73 };
74
75 let mut found: Option<&mut ThrottlerRecord> = None;
76 let mut lamest: Option<&mut ThrottlerRecord> = None;
77 for record in &mut self.records {
78 if record.ip == ip {
80 found = Some(record);
81 break;
82 }
83
84 if let Some(lame) = &lamest {
86 if record.packets < lame.packets
87 || (record.packets == lame.packets && record.expiration < lame.expiration)
88 {
89 lamest = Some(record);
90 }
91 } else {
92 lamest = Some(record)
93 }
94 }
95
96 if let Some(found) = found {
97 if let Some(since_creation) = now.checked_duration_since(found.creation_time) {
99 if since_creation > self.max_tracking {
100 found.clear();
101 return false;
102 }
103 }
104
105 if now < found.expiration {
106 found.packets = found
107 .packets
108 .checked_add(count.unwrap_or(1))
109 .unwrap_or(PacketCount::MAX);
110 } else {
111 found.packets = count.unwrap_or(1);
112 found.expiration = now + self.period;
113 }
114 if found.packets > self.rate_limit {
115 debug!(target: "rustydht_lib::Throttler", "{} is throttled for {:?}. {} packets on record", ip, self.naughty_timeout, found.packets);
116 found.expiration = now + self.naughty_timeout;
117 return true;
118 }
119 } else if let Some(lamest) = lamest {
120 lamest.packets = count.unwrap_or(1);
121 lamest.expiration = now + self.period;
122 lamest.ip = ip;
123 lamest.creation_time = now;
124
125 if lamest.packets > self.rate_limit {
126 debug!(target: "rustydht_lib::Throttler", "{} is throttled for {:?}. {} packets on record", ip, self.naughty_timeout, lamest.packets);
127 lamest.expiration = now + self.naughty_timeout;
128 return true;
129 }
130 } else {
131 panic!("This should never happen ;)");
132 }
133
134 false
135 }
136
137 pub fn get_num_records(&self) -> usize {
138 NUM_RECORDS
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// 32-slot throttler: 1 packet per 5s window, 10s max tracking, configurable penalty.
    fn new_throttler(naughty_secs: u64) -> Throttler<32> {
        Throttler::<32>::new(
            1,
            Duration::from_secs(5),
            Duration::from_secs(naughty_secs),
            Duration::from_secs(10),
        )
    }

    /// An address on 192.168.50.0/24.
    fn lan_ip(last_octet: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 50, last_octet))
    }

    /// A point in time `secs` seconds after the moment of the call.
    fn secs_from_now(secs: u64) -> Instant {
        Instant::now() + Duration::from_secs(secs)
    }

    #[test]
    fn test_one_two_punch() {
        let mut throttler = new_throttler(1);
        let ip = lan_ip(1);

        // The first packet fits the limit; the second exceeds it.
        assert!(!throttler.check_throttle(ip, None, None));
        assert!(throttler.check_throttle(ip, None, None));
        // Two seconds later the 1s penalty has run out.
        assert!(!throttler.check_throttle(ip, Some(secs_from_now(2)), None));
    }

    #[test]
    fn test_lots_of_ips() {
        let mut throttler = new_throttler(1);

        // Send from one more IP than there are slots, so the last one must evict a record.
        let capacity = throttler.get_num_records();
        for n in 0..=capacity {
            assert!(!throttler.check_throttle(lan_ip(n.try_into().unwrap()), None, None));
        }

        let newest = lan_ip(capacity.try_into().unwrap());
        assert!(throttler.check_throttle(newest, None, None));
        assert!(!throttler.check_throttle(newest, Some(secs_from_now(2)), None));
    }

    #[test]
    fn test_max_tracking() {
        let mut throttler = new_throttler(100);
        let ip = lan_ip(1);

        assert!(!throttler.check_throttle(ip, None, None));
        assert!(throttler.check_throttle(ip, None, None));
        // At 9s the IP is still inside both the penalty and the tracking limit.
        assert!(throttler.check_throttle(ip, Some(secs_from_now(9)), None));
        // At 11s the record is older than max_tracking (10s) and gets forgotten.
        assert!(!throttler.check_throttle(ip, Some(secs_from_now(11)), None));
    }

    #[test]
    fn test_avoids_overflow() {
        let mut throttler = new_throttler(100);
        let ip = lan_ip(1);

        // A maximal burst is throttled, and one more packet must not overflow the counter.
        assert!(throttler.check_throttle(ip, None, Some(PacketCount::MAX)));
        assert!(throttler.check_throttle(ip, None, None));
    }
}