// rusmes_core/rate_limit.rs
1//! Rate limiting for connection and message processing
2//!
3//! Provides per-IP, per-sender, and combined IP+sender rate limiting.
4//! Bucket state can be persisted to a JSON file and reloaded on startup,
5//! so limits survive server restarts.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::net::IpAddr;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinHandle;
15
16// ── Key types ──────────────────────────────────────────────────────────────
17
/// Identifies the axis on which a rate limit is applied.
///
/// Passed to [`RateLimiter::allow_message_keyed`] to select which bucket a
/// message counts against, and serialized via `to_key_string` into the
/// string keys used for in-memory bucket storage and JSON persistence.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RateLimitKey {
    /// Limit by remote client IP address
    Ip(IpAddr),
    /// Limit by MAIL FROM envelope sender address
    Sender(String),
    /// Limit by (IP, sender) pair simultaneously
    IpAndSender(IpAddr, String),
}
28
29impl RateLimitKey {
30    /// Serialize to a compact string suitable for use as a JSON map key.
31    fn to_key_string(&self) -> String {
32        match self {
33            RateLimitKey::Ip(ip) => format!("ip:{}", ip),
34            RateLimitKey::Sender(addr) => format!("sender:{}", addr),
35            RateLimitKey::IpAndSender(ip, addr) => format!("ip+sender:{}:{}", ip, addr),
36        }
37    }
38}
39
40// ── Configuration ──────────────────────────────────────────────────────────
41
/// Rate limiter configuration
///
/// `window_duration` round-trips through serde as a plain number of whole
/// seconds (see the `duration_secs_serde` helper module).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum connections per IP address (sliding window)
    pub max_connections_per_ip: usize,
    /// Maximum messages per window per rate-limit key
    pub max_messages_per_window: usize,
    /// Duration of the rate-limit time window
    #[serde(with = "duration_secs_serde")]
    pub window_duration: Duration,
    /// How often (seconds) the bucket state is persisted to disk.
    /// None disables persistence.
    pub persist_interval_secs: Option<u64>,
    /// Directory where `ratelimit.json` is written.
    /// None disables persistence.
    pub runtime_dir: Option<PathBuf>,
}
59
60impl Default for RateLimitConfig {
61    fn default() -> Self {
62        Self {
63            max_connections_per_ip: 10,
64            max_messages_per_window: 100,
65            window_duration: Duration::from_secs(3600), // 1 hour
66            persist_interval_secs: Some(60),
67            runtime_dir: None,
68        }
69    }
70}
71
72mod duration_secs_serde {
73    use serde::{Deserialize, Deserializer, Serialize, Serializer};
74    use std::time::Duration;
75
76    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        duration.as_secs().serialize(serializer)
81    }
82
83    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
84    where
85        D: Deserializer<'de>,
86    {
87        let secs = u64::deserialize(deserializer)?;
88        Ok(Duration::from_secs(secs))
89    }
90}
91
92// ── Bucket state (serializable) ────────────────────────────────────────────
93
/// A single message-count bucket entry — serializable for persistence.
///
/// Window starts are recorded as wall-clock Unix seconds rather than a
/// monotonic `Instant` so that entries remain meaningful after being
/// written to disk and restored in a later process.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BucketEntry {
    // Messages counted in the current window; initialized to 1 on creation
    // (the message that opened the window counts).
    count: usize,
    /// Unix timestamp (seconds) of when this window started
    window_start_secs: u64,
}
101
102impl BucketEntry {
103    fn new(now: Instant) -> Self {
104        Self {
105            count: 1,
106            window_start_secs: unix_secs_from_instant(now),
107        }
108    }
109
110    fn is_expired(&self, window_duration: Duration) -> bool {
111        let elapsed = unix_secs_now().saturating_sub(self.window_start_secs);
112        elapsed >= window_duration.as_secs()
113    }
114}
115
/// Snapshot that maps string-keyed buckets to their entry data
///
/// This struct defines the on-disk JSON layout of `ratelimit.json`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct BucketSnapshot {
    /// message-count buckets, keyed by `RateLimitKey::to_key_string()`
    messages: HashMap<String, BucketEntry>,
}
122
123// ── Connection counter (not persisted — transient, session-scoped) ─────────
124
/// Live connection count for one IP plus when tracking began.
///
/// Deliberately not serializable: connection state is transient and
/// meaningless across a restart, unlike the message buckets.
#[derive(Debug, Clone)]
struct ConnectionEntry {
    // Open connections currently attributed to this IP.
    count: usize,
    // When this IP's tracking window began (monotonic clock is fine here
    // because this data never survives the process).
    first_seen: Instant,
}
130
131// ── RateLimiter ───────────────────────────────────────────────────────────
132
/// Rate limiter for SMTP connections and messages.
///
/// Supports three keying strategies:
///  - `RateLimitKey::Ip` — the classic per-IP limit
///  - `RateLimitKey::Sender` — per MAIL FROM address
///  - `RateLimitKey::IpAndSender` — combined, tightest control
///
/// State is periodically snapshotted to `<runtime_dir>/ratelimit.json`
/// (if `runtime_dir` is configured) and re-loaded on startup.
pub struct RateLimiter {
    // RwLock so config hot-reload writes don't serialize the read-mostly
    // limit checks.
    config: Arc<RwLock<RateLimitConfig>>,
    // Per-IP live connection counts; transient, never persisted.
    connections: Arc<Mutex<HashMap<IpAddr, ConnectionEntry>>>,
    // Message buckets keyed by `RateLimitKey::to_key_string()`; this is the
    // map that gets snapshotted to disk by the persistence task.
    buckets: Arc<Mutex<HashMap<String, BucketEntry>>>,
}
147
148impl RateLimiter {
149    /// Create a new rate limiter.
150    ///
151    /// This constructor is sync-callable and does **not** spawn any background
152    /// tasks; it can therefore be invoked outside of a Tokio runtime (e.g. in
153    /// `#[test]` blocks or during synchronous wiring code).
154    ///
155    /// To enable periodic snapshotting of bucket state to disk, call
156    /// [`RateLimiter::start_persistence_task`] from inside an async context
157    /// after construction.
158    pub fn new(config: RateLimitConfig) -> Self {
159        let buckets = Arc::new(Mutex::new(HashMap::new()));
160        let config_arc = Arc::new(RwLock::new(config));
161
162        Self {
163            config: Arc::clone(&config_arc),
164            connections: Arc::new(Mutex::new(HashMap::new())),
165            buckets: Arc::clone(&buckets),
166        }
167    }
168
169    /// Create a rate limiter and immediately restore persisted state from `snapshot_path`.
170    ///
171    /// This is the production constructor; `new()` is the simpler form that
172    /// relies on the runtime_dir in config. Use this to control the path explicitly
173    /// (handy in tests).
174    ///
175    /// As with [`RateLimiter::new`], no background persistence task is spawned —
176    /// call [`RateLimiter::start_persistence_task`] explicitly afterwards.
177    pub async fn new_with_restore(config: RateLimitConfig, snapshot_path: &Path) -> Self {
178        let buckets = Arc::new(Mutex::new(HashMap::new()));
179
180        // Try to load persisted state
181        if let Err(e) = restore_from_file(&buckets, snapshot_path).await {
182            tracing::warn!(
183                "Rate limit state not restored from {:?}: {}",
184                snapshot_path,
185                e
186            );
187        } else {
188            tracing::info!("Rate limit state restored from {:?}", snapshot_path);
189        }
190
191        let config_arc = Arc::new(RwLock::new(config));
192
193        Self {
194            config: config_arc,
195            connections: Arc::new(Mutex::new(HashMap::new())),
196            buckets,
197        }
198    }
199
200    /// Start the background persistence task.
201    ///
202    /// Spawns a Tokio task that snapshots the message-bucket state to
203    /// `<runtime_dir>/ratelimit.json` every `interval`. Returns the
204    /// [`JoinHandle`] so callers can manage the task lifecycle if desired.
205    ///
206    /// **Must be called from within a Tokio runtime.**
207    pub fn start_persistence_task(
208        &self,
209        runtime_dir: PathBuf,
210        interval: Duration,
211    ) -> JoinHandle<()> {
212        let buckets = Arc::clone(&self.buckets);
213        tokio::spawn(async move {
214            persistence_task(runtime_dir, interval, buckets).await;
215        })
216    }
217
218    /// Snapshot the current bucket state to `path` (JSON format).
219    pub async fn snapshot_to_file(&self, path: &Path) -> anyhow::Result<()> {
220        let guard = self.buckets.lock().await;
221        snapshot_to_file_locked(&guard, path).await
222    }
223
224    /// Restore bucket state from a JSON snapshot file.
225    pub async fn restore_from_file(&self, path: &Path) -> anyhow::Result<()> {
226        restore_from_file(&self.buckets, path).await
227    }
228
229    /// Update the rate limiter configuration (hot-reload support)
230    pub async fn update_config(&self, new_config: RateLimitConfig) {
231        let mut config = self.config.write().await;
232        *config = new_config;
233    }
234
235    /// Check if a connection from this IP is allowed
236    pub async fn allow_connection(&self, ip: IpAddr) -> bool {
237        let config = self.config.read().await;
238        let mut connections = self.connections.lock().await;
239
240        // Clean up old entries
241        let now = Instant::now();
242        let window_duration = config.window_duration;
243        connections.retain(|_, entry| now.duration_since(entry.first_seen) < window_duration);
244
245        // Check current count
246        let max_connections = config.max_connections_per_ip;
247        match connections.get_mut(&ip) {
248            Some(entry) => {
249                if entry.count >= max_connections {
250                    tracing::warn!("Connection rate limit exceeded for IP: {}", ip);
251                    false
252                } else {
253                    entry.count += 1;
254                    true
255                }
256            }
257            None => {
258                connections.insert(
259                    ip,
260                    ConnectionEntry {
261                        count: 1,
262                        first_seen: now,
263                    },
264                );
265                true
266            }
267        }
268    }
269
270    /// Release a connection
271    pub async fn release_connection(&self, ip: IpAddr) {
272        let mut connections = self.connections.lock().await;
273        if let Some(entry) = connections.get_mut(&ip) {
274            if entry.count > 0 {
275                entry.count -= 1;
276            }
277            if entry.count == 0 {
278                connections.remove(&ip);
279            }
280        }
281    }
282
283    /// Check if a message for the given key is allowed (generic key variant).
284    ///
285    /// This is the primary per-sender/per-IP message check entry point.
286    pub async fn allow_message_keyed(&self, key: &RateLimitKey) -> bool {
287        let config = self.config.read().await;
288        let max_messages = config.max_messages_per_window;
289        let window_duration = config.window_duration;
290        drop(config); // release read lock before locking buckets
291
292        let key_str = key.to_key_string();
293        let mut buckets = self.buckets.lock().await;
294
295        match buckets.get_mut(&key_str) {
296            Some(entry) => {
297                if entry.is_expired(window_duration) {
298                    // Reset window
299                    *entry = BucketEntry::new(Instant::now());
300                    true
301                } else if entry.count >= max_messages {
302                    tracing::warn!("Message rate limit exceeded for key: {}", key_str);
303                    false
304                } else {
305                    entry.count += 1;
306                    true
307                }
308            }
309            None => {
310                buckets.insert(key_str, BucketEntry::new(Instant::now()));
311                true
312            }
313        }
314    }
315
316    /// Check if a message from this IP is allowed (legacy IP-only API for backwards compat)
317    pub async fn allow_message(&self, ip: IpAddr) -> bool {
318        self.allow_message_keyed(&RateLimitKey::Ip(ip)).await
319    }
320
321    /// Check if a message from this sender is allowed.
322    pub async fn allow_message_from_sender(&self, sender: &str) -> bool {
323        self.allow_message_keyed(&RateLimitKey::Sender(sender.to_string()))
324            .await
325    }
326
327    /// Check if a message is allowed based on both IP and sender.
328    pub async fn allow_message_ip_and_sender(&self, ip: IpAddr, sender: &str) -> bool {
329        self.allow_message_keyed(&RateLimitKey::IpAndSender(ip, sender.to_string()))
330            .await
331    }
332
333    /// Get current connection count for an IP
334    pub async fn get_connection_count(&self, ip: IpAddr) -> usize {
335        let connections = self.connections.lock().await;
336        connections.get(&ip).map(|e| e.count).unwrap_or(0)
337    }
338
339    /// Get current message count for a key (for debugging/testing)
340    pub async fn get_message_count_keyed(&self, key: &RateLimitKey) -> usize {
341        let buckets = self.buckets.lock().await;
342        buckets
343            .get(&key.to_key_string())
344            .map(|e| e.count)
345            .unwrap_or(0)
346    }
347
348    /// Get current message count for an IP (legacy)
349    pub async fn get_message_count(&self, ip: IpAddr) -> usize {
350        self.get_message_count_keyed(&RateLimitKey::Ip(ip)).await
351    }
352}
353
354// ── Persistence helpers ───────────────────────────────────────────────────
355
/// Full path of the snapshot file (`ratelimit.json`) inside `runtime_dir`.
fn ratelimit_file_path(runtime_dir: &Path) -> PathBuf {
    let mut path = runtime_dir.to_path_buf();
    path.push("ratelimit.json");
    path
}
359
360async fn snapshot_to_file_locked(
361    buckets: &HashMap<String, BucketEntry>,
362    path: &Path,
363) -> anyhow::Result<()> {
364    let snapshot = BucketSnapshot {
365        messages: buckets.clone(),
366    };
367    let json = serde_json::to_string_pretty(&snapshot)?;
368    tokio::fs::write(path, json).await?;
369    Ok(())
370}
371
372async fn restore_from_file(
373    buckets: &Mutex<HashMap<String, BucketEntry>>,
374    path: &Path,
375) -> anyhow::Result<()> {
376    if !tokio::fs::try_exists(path).await? {
377        return Ok(());
378    }
379    let json = tokio::fs::read_to_string(path).await?;
380    let snapshot: BucketSnapshot = serde_json::from_str(&json)?;
381    let mut guard = buckets.lock().await;
382    *guard = snapshot.messages;
383    Ok(())
384}
385
386/// Background task that periodically persists rate limit state.
387async fn persistence_task(
388    runtime_dir: PathBuf,
389    interval: Duration,
390    buckets: Arc<Mutex<HashMap<String, BucketEntry>>>,
391) {
392    let path = ratelimit_file_path(&runtime_dir);
393    loop {
394        tokio::time::sleep(interval).await;
395
396        let guard = buckets.lock().await;
397        if let Err(e) = snapshot_to_file_locked(&guard, &path).await {
398            tracing::warn!("Failed to persist rate limit state to {:?}: {}", path, e);
399        } else {
400            tracing::debug!("Rate limit state persisted to {:?}", path);
401        }
402    }
403}
404
405// ── Utility ───────────────────────────────────────────────────────────────
406
/// Current Unix timestamp in seconds (wall-clock, not monotonic).
/// Falls back to 0 instead of panicking if the system clock reads as
/// earlier than the Unix epoch.
fn unix_secs_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_secs())
}
414
/// Convert a monotonic Instant to an approximate Unix timestamp.
/// Used only for the initial `window_start_secs` field.
///
/// NOTE(review): the `Instant` argument is currently ignored — this just
/// reads the wall clock. All call sites pass `Instant::now()`, so the result
/// is equivalent today, but confirm before ever passing a historical
/// instant.
fn unix_secs_from_instant(_instant: Instant) -> u64 {
    unix_secs_now()
}
420
421// ── Tests ─────────────────────────────────────────────────────────────────
422
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    // Builds a config with a tight connection cap (2 per IP), the given
    // message cap, a 1-hour window, and persistence disabled so tests touch
    // no files.
    fn test_config(max_messages: usize) -> RateLimitConfig {
        RateLimitConfig {
            max_connections_per_ip: 2,
            max_messages_per_window: max_messages,
            window_duration: Duration::from_secs(3600),
            persist_interval_secs: None, // Don't spawn the background task interval
            runtime_dir: None,
        }
    }

    // Cap of 2 connections: third is refused; releasing one frees a slot.
    #[tokio::test]
    async fn test_connection_limit() {
        let limiter = RateLimiter::new(test_config(100));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.allow_connection(ip).await);
        assert!(limiter.allow_connection(ip).await);
        assert!(!limiter.allow_connection(ip).await);

        limiter.release_connection(ip).await;
        assert!(limiter.allow_connection(ip).await);
    }

    // Legacy per-IP message API: cap of 2 messages per window.
    #[tokio::test]
    async fn test_message_limit() {
        // NOTE(review): the explicit max_messages_per_window here is
        // redundant — test_config(2) already sets it to the same value.
        let config = RateLimitConfig {
            max_connections_per_ip: 10,
            max_messages_per_window: 2,
            ..test_config(2)
        };
        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));

        assert!(limiter.allow_message(ip).await);
        assert!(limiter.allow_message(ip).await);
        assert!(!limiter.allow_message(ip).await);
    }

    #[tokio::test]
    async fn per_sender_rate_limit_sixth_rejected() {
        // 5 messages from spammer@x.com with limit=5 → 6th rejected
        let config = RateLimitConfig {
            max_messages_per_window: 5,
            persist_interval_secs: None,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        let sender = "spammer@x.com";

        for i in 1..=5 {
            let allowed = limiter.allow_message_from_sender(sender).await;
            assert!(allowed, "Message {} should be allowed", i);
        }

        let sixth_allowed = limiter.allow_message_from_sender(sender).await;
        assert!(!sixth_allowed, "6th message should be rejected");
    }

    // Snapshot bucket state to disk, rebuild a limiter from the file, and
    // verify the per-sender count survived the "restart".
    #[tokio::test]
    async fn rate_limit_persistence_roundtrip() {
        // uuid gives each test run a unique temp dir, so parallel runs
        // cannot collide on the snapshot path.
        let tmp_dir = std::env::temp_dir().join(format!("rusmes_rl_test_{}", uuid::Uuid::new_v4()));
        tokio::fs::create_dir_all(&tmp_dir).await.unwrap();
        let snapshot_path = tmp_dir.join("ratelimit.json");

        // Create a limiter, add some bucket state
        {
            let config = RateLimitConfig {
                max_messages_per_window: 100,
                persist_interval_secs: None,
                runtime_dir: None,
                ..Default::default()
            };
            let limiter = RateLimiter::new(config);

            // Record 3 messages from spammer@example.com
            for _ in 0..3 {
                limiter
                    .allow_message_from_sender("spammer@example.com")
                    .await;
            }

            // Snapshot
            limiter.snapshot_to_file(&snapshot_path).await.unwrap();
        }

        // Reload into a new limiter
        {
            let config = RateLimitConfig {
                max_messages_per_window: 100,
                persist_interval_secs: None,
                runtime_dir: None,
                ..Default::default()
            };
            let limiter = RateLimiter::new_with_restore(config, &snapshot_path).await;

            let count = limiter
                .get_message_count_keyed(&RateLimitKey::Sender("spammer@example.com".to_string()))
                .await;
            assert_eq!(count, 3, "Bucket count should be preserved across restart");
        }

        // Cleanup
        let _ = tokio::fs::remove_dir_all(&tmp_dir).await;
    }

    // Combined (IP, sender) keying: the pair shares one bucket, but a
    // different IP with the same sender gets an independent bucket.
    #[tokio::test]
    async fn rate_limit_ip_and_sender_key() {
        let config = RateLimitConfig {
            max_messages_per_window: 2,
            persist_interval_secs: None,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let sender = "user@spammer.com";

        assert!(limiter.allow_message_ip_and_sender(ip, sender).await);
        assert!(limiter.allow_message_ip_and_sender(ip, sender).await);
        assert!(!limiter.allow_message_ip_and_sender(ip, sender).await);

        // Different IP with same sender should be independent
        let ip2 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        assert!(limiter.allow_message_ip_and_sender(ip2, sender).await);
    }
}