1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::net::IpAddr;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinHandle;
15
/// Identity under which messages are counted against a limit.
///
/// Each variant renders to a distinct string prefix (see
/// `RateLimitKey::to_key_string`), so the three kinds never share a bucket.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RateLimitKey {
    /// One bucket per client IP address.
    Ip(IpAddr),
    /// One bucket per sender address string, as passed by the caller.
    Sender(String),
    /// One bucket per (client IP, sender address) pair.
    IpAndSender(IpAddr, String),
}
28
29impl RateLimitKey {
30 fn to_key_string(&self) -> String {
32 match self {
33 RateLimitKey::Ip(ip) => format!("ip:{}", ip),
34 RateLimitKey::Sender(addr) => format!("sender:{}", addr),
35 RateLimitKey::IpAndSender(ip, addr) => format!("ip+sender:{}:{}", ip, addr),
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct RateLimitConfig {
45 pub max_connections_per_ip: usize,
47 pub max_messages_per_window: usize,
49 #[serde(with = "duration_secs_serde")]
51 pub window_duration: Duration,
52 pub persist_interval_secs: Option<u64>,
55 pub runtime_dir: Option<PathBuf>,
58}
59
60impl Default for RateLimitConfig {
61 fn default() -> Self {
62 Self {
63 max_connections_per_ip: 10,
64 max_messages_per_window: 100,
65 window_duration: Duration::from_secs(3600), persist_interval_secs: Some(60),
67 runtime_dir: None,
68 }
69 }
70}
71
72mod duration_secs_serde {
73 use serde::{Deserialize, Deserializer, Serialize, Serializer};
74 use std::time::Duration;
75
76 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 duration.as_secs().serialize(serializer)
81 }
82
83 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
84 where
85 D: Deserializer<'de>,
86 {
87 let secs = u64::deserialize(deserializer)?;
88 Ok(Duration::from_secs(secs))
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
96struct BucketEntry {
97 count: usize,
98 window_start_secs: u64,
100}
101
102impl BucketEntry {
103 fn new(now: Instant) -> Self {
104 Self {
105 count: 1,
106 window_start_secs: unix_secs_from_instant(now),
107 }
108 }
109
110 fn is_expired(&self, window_duration: Duration) -> bool {
111 let elapsed = unix_secs_now().saturating_sub(self.window_start_secs);
112 elapsed >= window_duration.as_secs()
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, Default)]
118struct BucketSnapshot {
119 messages: HashMap<String, BucketEntry>,
121}
122
/// Connection bookkeeping for one client IP.
#[derive(Debug, Clone)]
struct ConnectionEntry {
    /// Connections currently counted against the IP.
    count: usize,
    /// When this entry was created; used to sweep entries older than the window.
    first_seen: Instant,
}
130
131pub struct RateLimiter {
143 config: Arc<RwLock<RateLimitConfig>>,
144 connections: Arc<Mutex<HashMap<IpAddr, ConnectionEntry>>>,
145 buckets: Arc<Mutex<HashMap<String, BucketEntry>>>,
146}
147
148impl RateLimiter {
149 pub fn new(config: RateLimitConfig) -> Self {
159 let buckets = Arc::new(Mutex::new(HashMap::new()));
160 let config_arc = Arc::new(RwLock::new(config));
161
162 Self {
163 config: Arc::clone(&config_arc),
164 connections: Arc::new(Mutex::new(HashMap::new())),
165 buckets: Arc::clone(&buckets),
166 }
167 }
168
169 pub async fn new_with_restore(config: RateLimitConfig, snapshot_path: &Path) -> Self {
178 let buckets = Arc::new(Mutex::new(HashMap::new()));
179
180 if let Err(e) = restore_from_file(&buckets, snapshot_path).await {
182 tracing::warn!(
183 "Rate limit state not restored from {:?}: {}",
184 snapshot_path,
185 e
186 );
187 } else {
188 tracing::info!("Rate limit state restored from {:?}", snapshot_path);
189 }
190
191 let config_arc = Arc::new(RwLock::new(config));
192
193 Self {
194 config: config_arc,
195 connections: Arc::new(Mutex::new(HashMap::new())),
196 buckets,
197 }
198 }
199
200 pub fn start_persistence_task(
208 &self,
209 runtime_dir: PathBuf,
210 interval: Duration,
211 ) -> JoinHandle<()> {
212 let buckets = Arc::clone(&self.buckets);
213 tokio::spawn(async move {
214 persistence_task(runtime_dir, interval, buckets).await;
215 })
216 }
217
218 pub async fn snapshot_to_file(&self, path: &Path) -> anyhow::Result<()> {
220 let guard = self.buckets.lock().await;
221 snapshot_to_file_locked(&guard, path).await
222 }
223
224 pub async fn restore_from_file(&self, path: &Path) -> anyhow::Result<()> {
226 restore_from_file(&self.buckets, path).await
227 }
228
229 pub async fn update_config(&self, new_config: RateLimitConfig) {
231 let mut config = self.config.write().await;
232 *config = new_config;
233 }
234
235 pub async fn allow_connection(&self, ip: IpAddr) -> bool {
237 let config = self.config.read().await;
238 let mut connections = self.connections.lock().await;
239
240 let now = Instant::now();
242 let window_duration = config.window_duration;
243 connections.retain(|_, entry| now.duration_since(entry.first_seen) < window_duration);
244
245 let max_connections = config.max_connections_per_ip;
247 match connections.get_mut(&ip) {
248 Some(entry) => {
249 if entry.count >= max_connections {
250 tracing::warn!("Connection rate limit exceeded for IP: {}", ip);
251 false
252 } else {
253 entry.count += 1;
254 true
255 }
256 }
257 None => {
258 connections.insert(
259 ip,
260 ConnectionEntry {
261 count: 1,
262 first_seen: now,
263 },
264 );
265 true
266 }
267 }
268 }
269
270 pub async fn release_connection(&self, ip: IpAddr) {
272 let mut connections = self.connections.lock().await;
273 if let Some(entry) = connections.get_mut(&ip) {
274 if entry.count > 0 {
275 entry.count -= 1;
276 }
277 if entry.count == 0 {
278 connections.remove(&ip);
279 }
280 }
281 }
282
283 pub async fn allow_message_keyed(&self, key: &RateLimitKey) -> bool {
287 let config = self.config.read().await;
288 let max_messages = config.max_messages_per_window;
289 let window_duration = config.window_duration;
290 drop(config); let key_str = key.to_key_string();
293 let mut buckets = self.buckets.lock().await;
294
295 match buckets.get_mut(&key_str) {
296 Some(entry) => {
297 if entry.is_expired(window_duration) {
298 *entry = BucketEntry::new(Instant::now());
300 true
301 } else if entry.count >= max_messages {
302 tracing::warn!("Message rate limit exceeded for key: {}", key_str);
303 false
304 } else {
305 entry.count += 1;
306 true
307 }
308 }
309 None => {
310 buckets.insert(key_str, BucketEntry::new(Instant::now()));
311 true
312 }
313 }
314 }
315
316 pub async fn allow_message(&self, ip: IpAddr) -> bool {
318 self.allow_message_keyed(&RateLimitKey::Ip(ip)).await
319 }
320
321 pub async fn allow_message_from_sender(&self, sender: &str) -> bool {
323 self.allow_message_keyed(&RateLimitKey::Sender(sender.to_string()))
324 .await
325 }
326
327 pub async fn allow_message_ip_and_sender(&self, ip: IpAddr, sender: &str) -> bool {
329 self.allow_message_keyed(&RateLimitKey::IpAndSender(ip, sender.to_string()))
330 .await
331 }
332
333 pub async fn get_connection_count(&self, ip: IpAddr) -> usize {
335 let connections = self.connections.lock().await;
336 connections.get(&ip).map(|e| e.count).unwrap_or(0)
337 }
338
339 pub async fn get_message_count_keyed(&self, key: &RateLimitKey) -> usize {
341 let buckets = self.buckets.lock().await;
342 buckets
343 .get(&key.to_key_string())
344 .map(|e| e.count)
345 .unwrap_or(0)
346 }
347
348 pub async fn get_message_count(&self, ip: IpAddr) -> usize {
350 self.get_message_count_keyed(&RateLimitKey::Ip(ip)).await
351 }
352}
353
/// Location of the rate-limit snapshot file inside `runtime_dir`.
fn ratelimit_file_path(runtime_dir: &Path) -> PathBuf {
    let mut snapshot = runtime_dir.to_path_buf();
    snapshot.push("ratelimit.json");
    snapshot
}
359
360async fn snapshot_to_file_locked(
361 buckets: &HashMap<String, BucketEntry>,
362 path: &Path,
363) -> anyhow::Result<()> {
364 let snapshot = BucketSnapshot {
365 messages: buckets.clone(),
366 };
367 let json = serde_json::to_string_pretty(&snapshot)?;
368 tokio::fs::write(path, json).await?;
369 Ok(())
370}
371
372async fn restore_from_file(
373 buckets: &Mutex<HashMap<String, BucketEntry>>,
374 path: &Path,
375) -> anyhow::Result<()> {
376 if !tokio::fs::try_exists(path).await? {
377 return Ok(());
378 }
379 let json = tokio::fs::read_to_string(path).await?;
380 let snapshot: BucketSnapshot = serde_json::from_str(&json)?;
381 let mut guard = buckets.lock().await;
382 *guard = snapshot.messages;
383 Ok(())
384}
385
386async fn persistence_task(
388 runtime_dir: PathBuf,
389 interval: Duration,
390 buckets: Arc<Mutex<HashMap<String, BucketEntry>>>,
391) {
392 let path = ratelimit_file_path(&runtime_dir);
393 loop {
394 tokio::time::sleep(interval).await;
395
396 let guard = buckets.lock().await;
397 if let Err(e) = snapshot_to_file_locked(&guard, &path).await {
398 tracing::warn!("Failed to persist rate limit state to {:?}: {}", path, e);
399 } else {
400 tracing::debug!("Rate limit state persisted to {:?}", path);
401 }
402 }
403}
404
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of an error.
fn unix_secs_now() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
414
/// Converts a monotonic `Instant` to wall-clock seconds since the Unix epoch.
///
/// `Instant` has no fixed epoch, so the conversion subtracts the instant's age
/// (`instant.elapsed()`) from the current system time. For `Instant::now()`
/// this is the current Unix time. For older instants it is the matching
/// earlier time; the previous version ignored the argument and always
/// returned "now". An instant in the future counts as age zero (that is how
/// `Instant::elapsed` behaves). Returns 0 if the result would fall before the
/// Unix epoch.
fn unix_secs_from_instant(instant: Instant) -> u64 {
    let age = instant.elapsed();
    std::time::SystemTime::now()
        .checked_sub(age)
        .and_then(|at| at.duration_since(std::time::UNIX_EPOCH).ok())
        .map(|since_epoch| since_epoch.as_secs())
        .unwrap_or(0)
}
420
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Config allowing 2 connections per IP and `max_messages` messages per
    /// one-hour window, with persistence disabled.
    fn test_config(max_messages: usize) -> RateLimitConfig {
        RateLimitConfig {
            max_connections_per_ip: 2,
            max_messages_per_window: max_messages,
            window_duration: Duration::from_secs(3600),
            persist_interval_secs: None,
            runtime_dir: None,
        }
    }

    /// Shorthand for an IPv4 `IpAddr`.
    fn ipv4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[tokio::test]
    async fn test_connection_limit() {
        let limiter = RateLimiter::new(test_config(100));
        let loopback = ipv4(127, 0, 0, 1);

        // The configured cap of two is admitted; the third is refused.
        for _ in 0..2 {
            assert!(limiter.allow_connection(loopback).await);
        }
        assert!(!limiter.allow_connection(loopback).await);

        // Releasing one connection frees exactly one slot.
        limiter.release_connection(loopback).await;
        assert!(limiter.allow_connection(loopback).await);
    }

    #[tokio::test]
    async fn test_message_limit() {
        let mut config = test_config(2);
        config.max_connections_per_ip = 10;
        let limiter = RateLimiter::new(config);
        let client = ipv4(192, 0, 2, 1);

        for _ in 0..2 {
            assert!(limiter.allow_message(client).await);
        }
        assert!(!limiter.allow_message(client).await);
    }

    #[tokio::test]
    async fn per_sender_rate_limit_sixth_rejected() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_messages_per_window: 5,
            persist_interval_secs: None,
            ..RateLimitConfig::default()
        });
        let sender = "spammer@x.com";

        for attempt in 1..=5 {
            assert!(
                limiter.allow_message_from_sender(sender).await,
                "Message {} should be allowed",
                attempt
            );
        }

        assert!(
            !limiter.allow_message_from_sender(sender).await,
            "6th message should be rejected"
        );
    }

    #[tokio::test]
    async fn rate_limit_persistence_roundtrip() {
        let tmp_dir =
            std::env::temp_dir().join(format!("rusmes_rl_test_{}", uuid::Uuid::new_v4()));
        tokio::fs::create_dir_all(&tmp_dir).await.unwrap();
        let snapshot_path = tmp_dir.join("ratelimit.json");
        let sender = "spammer@example.com";

        // Both simulated process lifetimes share the same configuration.
        let make_config = || RateLimitConfig {
            max_messages_per_window: 100,
            persist_interval_secs: None,
            runtime_dir: None,
            ..Default::default()
        };

        // First lifetime: record three messages, then snapshot and shut down.
        let before = RateLimiter::new(make_config());
        for _ in 0..3 {
            before.allow_message_from_sender(sender).await;
        }
        before.snapshot_to_file(&snapshot_path).await.unwrap();
        drop(before);

        // Second lifetime: restore and verify the count survived.
        let after = RateLimiter::new_with_restore(make_config(), &snapshot_path).await;
        let count = after
            .get_message_count_keyed(&RateLimitKey::Sender(sender.to_string()))
            .await;
        assert_eq!(count, 3, "Bucket count should be preserved across restart");

        let _ = tokio::fs::remove_dir_all(&tmp_dir).await;
    }

    #[tokio::test]
    async fn rate_limit_ip_and_sender_key() {
        let limiter = RateLimiter::new(RateLimitConfig {
            max_messages_per_window: 2,
            persist_interval_secs: None,
            ..Default::default()
        });
        let sender = "user@spammer.com";
        let first_ip = ipv4(10, 0, 0, 1);

        assert!(limiter.allow_message_ip_and_sender(first_ip, sender).await);
        assert!(limiter.allow_message_ip_and_sender(first_ip, sender).await);
        assert!(!limiter.allow_message_ip_and_sender(first_ip, sender).await);

        // The same sender from a different IP lands in its own bucket.
        let second_ip = ipv4(10, 0, 0, 2);
        assert!(limiter.allow_message_ip_and_sender(second_ip, sender).await);
    }
}