1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tokio::task::JoinHandle;
4
/// Per-IP accounting table.
///
/// Each value is a timestamp in nanoseconds relative to `IpRate::origin`.
/// It marks the point up to which that address's byte budget has already been spent.
type Map = HashMap<Arc<std::net::Ipv6Addr>, u64>;

/// Per-IP byte-rate limiter.
///
/// Every accepted or rejected message pushes the sender's entry in `map`
/// forward by `bytes * limit` nanoseconds. A message is accepted while that
/// entry stays no more than `burst` nanoseconds ahead of the current time.
pub struct IpRate {
    /// Reference instant; all timestamps stored in `map` are nanoseconds since this point.
    origin: tokio::time::Instant,
    /// Per-IP "budget spent until" timestamps.
    /// It sits behind a mutex because `is_ok` and `prune` both mutate it through `&self`.
    map: Arc<Mutex<Map>>,
    /// When true, `is_ok` accepts every message without touching `map`.
    disabled: bool,
    /// Nanoseconds of budget charged per byte (from `config.limit_ip_byte_nanos()`).
    limit: u64,
    /// How far, in nanoseconds, an entry may run ahead of "now" before
    /// messages are rejected (`limit_ip_byte_burst` bytes' worth of `limit`).
    burst: u64,
    /// Deny list: queried by `is_blocked`, and `block`ed by `is_ok` on rejection.
    ip_deny: crate::ip_deny::IpDeny,
}
16
17impl IpRate {
18 pub fn new(config: Arc<crate::Config>) -> Self {
20 Self {
21 origin: tokio::time::Instant::now(),
22 map: Arc::new(Mutex::new(HashMap::new())),
23 disabled: config.disable_rate_limiting,
24 limit: config.limit_ip_byte_nanos() as u64,
25 burst: config.limit_ip_byte_burst as u64
26 * config.limit_ip_byte_nanos() as u64,
27 ip_deny: crate::ip_deny::IpDeny::new(config),
28 }
29 }
30
31 pub fn prune(&self) {
37 let now = self.origin.elapsed().as_nanos() as u64;
38 self.map.lock().unwrap().retain(|_, cur| {
39 if now <= *cur {
40 true
41 } else {
42 now - *cur < 10_000_000_000
47 }
48 });
49 }
50
51 pub async fn is_blocked(&self, ip: &Arc<std::net::Ipv6Addr>) -> bool {
53 self.ip_deny.is_blocked(ip).await
54 }
55
56 pub async fn is_ok(
58 &self,
59 ip: &Arc<std::net::Ipv6Addr>,
60 bytes: usize,
61 ) -> bool {
62 if self.disabled {
63 return true;
64 }
65
66 let rate_add = bytes as u64 * self.limit;
68
69 let now = self.origin.elapsed().as_nanos() as u64;
71
72 let is_ok = {
73 let mut lock = self.map.lock().unwrap();
75
76 let e = lock.entry(ip.clone()).or_insert(now);
78
79 let cur = std::cmp::max(*e, now) + rate_add;
82
83 *e = cur;
85
86 cur - now <= self.burst
88 };
89
90 if !is_ok {
91 self.ip_deny.block(ip).await;
92 }
93
94 is_ok
95 }
96}
97
98pub fn spawn_prune_task(ip_rate: Arc<IpRate>) -> JoinHandle<()> {
100 let ip_rate = Arc::downgrade(&ip_rate);
101 tokio::task::spawn(async move {
102 loop {
103 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
104 if let Some(ip_rate) = ip_rate.upgrade() {
105 ip_rate.prune();
106 } else {
107 break;
108 }
109 }
110 })
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an enabled `IpRate` with explicit `limit`/`burst`.
    /// This bypasses `IpRate::new` so the tests control the numbers directly.
    fn test_new(limit: u64, burst: u64) -> IpRate {
        let deny_config = Arc::new(crate::Config::default());
        IpRate {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(HashMap::new())),
            disabled: false,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(deny_config),
        }
    }

    /// The single peer address every test talks about.
    fn peer() -> Arc<std::net::Ipv6Addr> {
        Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1))
    }

    /// Number of addresses currently tracked by `rate`.
    fn tracked(rate: &IpRate) -> usize {
        rate.map.lock().unwrap().len()
    }

    /// Move the paused tokio clock forward by `nanos` nanoseconds.
    async fn step_nanos(nanos: u64) {
        tokio::time::advance(std::time::Duration::from_nanos(nanos)).await;
    }

    /// Move the paused tokio clock forward by `secs` seconds.
    async fn step_secs(secs: u64) {
        tokio::time::advance(std::time::Duration::from_secs(secs)).await;
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_one_to_one() {
        let ip = peer();
        let rate = test_new(1, 1);

        // One nanosecond per byte with one nanosecond of slack: sending one
        // byte per elapsed nanosecond is always accepted.
        for _ in 0..10 {
            step_nanos(1).await;
            assert!(rate.is_ok(&ip, 1).await);
        }

        // A second byte within the same nanosecond exceeds the burst.
        assert!(!rate.is_ok(&ip, 1).await);

        step_nanos(1).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        // Still inside the retention window...
        step_secs(10).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        // ...and one nanosecond later the entry is gone.
        step_nanos(1).await;
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_burst() {
        let ip = peer();
        let rate = test_new(1, 5);

        // Five one-byte messages fit in the five-nanosecond burst at t = 0.
        for _ in 0..5 {
            assert!(rate.is_ok(&ip, 1).await);
        }
        // The sixth is rejected, but it is still charged.
        assert!(!rate.is_ok(&ip, 1).await);

        step_nanos(2).await;
        assert!(rate.is_ok(&ip, 1).await);

        step_secs(10).await;
        step_nanos(4).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        step_nanos(1).await;
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_limit_mult() {
        let ip = peer();
        let rate = test_new(3, 13);

        // Each 2-byte message costs 6 ns: two fit in a 13 ns burst, a third does not.
        let expected = [true, true, false];

        for &want in expected.iter() {
            assert_eq!(want, rate.is_ok(&ip, 2).await);
        }

        // After a long idle period the same pattern repeats.
        step_secs(10).await;

        for &want in expected.iter() {
            assert_eq!(want, rate.is_ok(&ip, 2).await);
        }
    }
}