Skip to main content

rusmes_core/
rate_limit.rs

1//! Rate limiting for connection and message processing
2
3use std::collections::HashMap;
4use std::net::IpAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::{Mutex, RwLock};
8
/// Tunable limits consumed by [`RateLimiter`].
///
/// All limits are tracked independently for each client IP address.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Upper bound on the number of connections counted for a single IP.
    pub max_connections_per_ip: usize,
    /// Upper bound on the number of messages accepted from a single IP within
    /// one `window_duration` (an hour with the default configuration).
    pub max_messages_per_hour: usize,
    /// Length of the counting window; counters older than this are discarded
    /// or restarted.
    pub window_duration: Duration,
}
19
20impl Default for RateLimitConfig {
21    fn default() -> Self {
22        Self {
23            max_connections_per_ip: 10,
24            max_messages_per_hour: 100,
25            window_duration: Duration::from_secs(3600), // 1 hour
26        }
27    }
28}
29
/// Per-IP bookkeeping for connection limiting.
#[derive(Clone, Debug)]
struct ConnectionEntry {
    /// Number of connections currently counted against the IP.
    count: usize,
    /// When this entry was created; used to expire the entry once it is
    /// older than the configured window.
    first_seen: Instant,
}
36
/// Per-IP bookkeeping for message limiting.
#[derive(Clone, Debug)]
struct MessageEntry {
    /// Messages accepted from the IP during the current window.
    count: usize,
    /// Start of the current counting window.
    window_start: Instant,
}
43
44/// Rate limiter for SMTP connections and messages
45pub struct RateLimiter {
46    config: Arc<RwLock<RateLimitConfig>>,
47    connections: Arc<Mutex<HashMap<IpAddr, ConnectionEntry>>>,
48    messages: Arc<Mutex<HashMap<IpAddr, MessageEntry>>>,
49}
50
51impl RateLimiter {
52    /// Create a new rate limiter
53    pub fn new(config: RateLimitConfig) -> Self {
54        Self {
55            config: Arc::new(RwLock::new(config)),
56            connections: Arc::new(Mutex::new(HashMap::new())),
57            messages: Arc::new(Mutex::new(HashMap::new())),
58        }
59    }
60
61    /// Update the rate limiter configuration (hot-reload support)
62    pub async fn update_config(&self, new_config: RateLimitConfig) {
63        let mut config = self.config.write().await;
64        *config = new_config;
65    }
66
67    /// Check if a connection from this IP is allowed
68    pub async fn allow_connection(&self, ip: IpAddr) -> bool {
69        let config = self.config.read().await;
70        let mut connections = self.connections.lock().await;
71
72        // Clean up old entries
73        let now = Instant::now();
74        let window_duration = config.window_duration;
75        connections.retain(|_, entry| now.duration_since(entry.first_seen) < window_duration);
76
77        // Check current count
78        let max_connections = config.max_connections_per_ip;
79        match connections.get_mut(&ip) {
80            Some(entry) => {
81                if entry.count >= max_connections {
82                    tracing::warn!("Connection rate limit exceeded for IP: {}", ip);
83                    false
84                } else {
85                    entry.count += 1;
86                    true
87                }
88            }
89            None => {
90                connections.insert(
91                    ip,
92                    ConnectionEntry {
93                        count: 1,
94                        first_seen: now,
95                    },
96                );
97                true
98            }
99        }
100    }
101
102    /// Release a connection
103    pub async fn release_connection(&self, ip: IpAddr) {
104        let mut connections = self.connections.lock().await;
105        if let Some(entry) = connections.get_mut(&ip) {
106            if entry.count > 0 {
107                entry.count -= 1;
108            }
109            if entry.count == 0 {
110                connections.remove(&ip);
111            }
112        }
113    }
114
115    /// Check if a message from this IP is allowed
116    pub async fn allow_message(&self, ip: IpAddr) -> bool {
117        let config = self.config.read().await;
118        let mut messages = self.messages.lock().await;
119
120        let now = Instant::now();
121        let window_duration = config.window_duration;
122        let max_messages = config.max_messages_per_hour;
123
124        match messages.get_mut(&ip) {
125            Some(entry) => {
126                // Check if we need to reset the window
127                if now.duration_since(entry.window_start) >= window_duration {
128                    entry.count = 1;
129                    entry.window_start = now;
130                    true
131                } else if entry.count >= max_messages {
132                    tracing::warn!("Message rate limit exceeded for IP: {}", ip);
133                    false
134                } else {
135                    entry.count += 1;
136                    true
137                }
138            }
139            None => {
140                messages.insert(
141                    ip,
142                    MessageEntry {
143                        count: 1,
144                        window_start: now,
145                    },
146                );
147                true
148            }
149        }
150    }
151
152    /// Get current connection count for an IP
153    pub async fn get_connection_count(&self, ip: IpAddr) -> usize {
154        let connections = self.connections.lock().await;
155        connections.get(&ip).map(|e| e.count).unwrap_or(0)
156    }
157
158    /// Get current message count for an IP
159    pub async fn get_message_count(&self, ip: IpAddr) -> usize {
160        let messages = self.messages.lock().await;
161        messages.get(&ip).map(|e| e.count).unwrap_or(0)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Build a limiter with a 60-second window and the given per-IP caps.
    fn limiter_with(max_connections_per_ip: usize, max_messages_per_hour: usize) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            max_connections_per_ip,
            max_messages_per_hour,
            window_duration: Duration::from_secs(60),
        })
    }

    /// The connection cap blocks the third connection and frees a slot again
    /// after a release.
    #[tokio::test]
    async fn test_connection_limit() {
        let limiter = limiter_with(2, 100);
        let ip = IpAddr::from(Ipv4Addr::LOCALHOST);

        // The cap is two, so both of these are admitted...
        for _ in 0..2 {
            assert!(limiter.allow_connection(ip).await);
        }

        // ...and the next one is refused.
        assert!(!limiter.allow_connection(ip).await);

        // Giving one slot back makes room for exactly one more connection.
        limiter.release_connection(ip).await;
        assert!(limiter.allow_connection(ip).await);
    }

    /// The message cap blocks the third message inside one window.
    #[tokio::test]
    async fn test_message_limit() {
        let limiter = limiter_with(10, 2);
        let ip = IpAddr::from([192, 0, 2, 1]);

        // The cap is two, so both of these are accepted...
        for _ in 0..2 {
            assert!(limiter.allow_message(ip).await);
        }

        // ...and the next one is refused.
        assert!(!limiter.allow_message(ip).await);
    }
}