rusmes_core/
rate_limit.rs1use std::collections::HashMap;
4use std::net::IpAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::{Mutex, RwLock};
8
/// Limits enforced by [`RateLimiter`].
///
/// All limits are per source IP address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum connections counted for one IP at a time; further connections are refused
    /// until one is released or the IP's entry expires after `window_duration`.
    pub max_connections_per_ip: usize,
    /// Maximum messages accepted from one IP within a single `window_duration`.
    ///
    /// Despite the name, the window is whatever `window_duration` is set to; "per hour" only
    /// holds for the one-hour default.
    pub max_messages_per_hour: usize,
    /// Length of the message-counting window, also used as the lifetime of connection entries.
    pub window_duration: Duration,
}
19
20impl Default for RateLimitConfig {
21 fn default() -> Self {
22 Self {
23 max_connections_per_ip: 10,
24 max_messages_per_hour: 100,
25 window_duration: Duration::from_secs(3600), }
27 }
28}
29
/// Connection bookkeeping for a single IP address.
#[derive(Clone, Debug)]
struct ConnectionEntry {
    /// Connections currently counted against the IP.
    count: usize,
    /// When this entry was created; the entry is discarded once it is older than the
    /// configured window.
    first_seen: Instant,
}
36
/// Message bookkeeping for a single IP address.
#[derive(Clone, Debug)]
struct MessageEntry {
    /// Messages accepted from the IP in the current window.
    count: usize,
    /// When the current window began.
    window_start: Instant,
}
43
44pub struct RateLimiter {
46 config: Arc<RwLock<RateLimitConfig>>,
47 connections: Arc<Mutex<HashMap<IpAddr, ConnectionEntry>>>,
48 messages: Arc<Mutex<HashMap<IpAddr, MessageEntry>>>,
49}
50
51impl RateLimiter {
52 pub fn new(config: RateLimitConfig) -> Self {
54 Self {
55 config: Arc::new(RwLock::new(config)),
56 connections: Arc::new(Mutex::new(HashMap::new())),
57 messages: Arc::new(Mutex::new(HashMap::new())),
58 }
59 }
60
61 pub async fn update_config(&self, new_config: RateLimitConfig) {
63 let mut config = self.config.write().await;
64 *config = new_config;
65 }
66
67 pub async fn allow_connection(&self, ip: IpAddr) -> bool {
69 let config = self.config.read().await;
70 let mut connections = self.connections.lock().await;
71
72 let now = Instant::now();
74 let window_duration = config.window_duration;
75 connections.retain(|_, entry| now.duration_since(entry.first_seen) < window_duration);
76
77 let max_connections = config.max_connections_per_ip;
79 match connections.get_mut(&ip) {
80 Some(entry) => {
81 if entry.count >= max_connections {
82 tracing::warn!("Connection rate limit exceeded for IP: {}", ip);
83 false
84 } else {
85 entry.count += 1;
86 true
87 }
88 }
89 None => {
90 connections.insert(
91 ip,
92 ConnectionEntry {
93 count: 1,
94 first_seen: now,
95 },
96 );
97 true
98 }
99 }
100 }
101
102 pub async fn release_connection(&self, ip: IpAddr) {
104 let mut connections = self.connections.lock().await;
105 if let Some(entry) = connections.get_mut(&ip) {
106 if entry.count > 0 {
107 entry.count -= 1;
108 }
109 if entry.count == 0 {
110 connections.remove(&ip);
111 }
112 }
113 }
114
115 pub async fn allow_message(&self, ip: IpAddr) -> bool {
117 let config = self.config.read().await;
118 let mut messages = self.messages.lock().await;
119
120 let now = Instant::now();
121 let window_duration = config.window_duration;
122 let max_messages = config.max_messages_per_hour;
123
124 match messages.get_mut(&ip) {
125 Some(entry) => {
126 if now.duration_since(entry.window_start) >= window_duration {
128 entry.count = 1;
129 entry.window_start = now;
130 true
131 } else if entry.count >= max_messages {
132 tracing::warn!("Message rate limit exceeded for IP: {}", ip);
133 false
134 } else {
135 entry.count += 1;
136 true
137 }
138 }
139 None => {
140 messages.insert(
141 ip,
142 MessageEntry {
143 count: 1,
144 window_start: now,
145 },
146 );
147 true
148 }
149 }
150 }
151
152 pub async fn get_connection_count(&self, ip: IpAddr) -> usize {
154 let connections = self.connections.lock().await;
155 connections.get(&ip).map(|e| e.count).unwrap_or(0)
156 }
157
158 pub async fn get_message_count(&self, ip: IpAddr) -> usize {
160 let messages = self.messages.lock().await;
161 messages.get(&ip).map(|e| e.count).unwrap_or(0)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a config with the given limits and a 60-second window.
    fn config(max_connections: usize, max_messages: usize) -> RateLimitConfig {
        RateLimitConfig {
            max_connections_per_ip: max_connections,
            max_messages_per_hour: max_messages,
            window_duration: Duration::from_secs(60),
        }
    }

    #[tokio::test]
    async fn test_connection_limit() {
        let limiter = RateLimiter::new(config(2, 100));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Two connections fit within the limit; the third is refused.
        assert!(limiter.allow_connection(ip).await);
        assert!(limiter.allow_connection(ip).await);
        assert!(!limiter.allow_connection(ip).await);

        // Releasing one frees a slot again.
        limiter.release_connection(ip).await;
        assert!(limiter.allow_connection(ip).await);
    }

    #[tokio::test]
    async fn test_message_limit() {
        let limiter = RateLimiter::new(config(10, 2));
        let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));

        // Two messages fit within the window; the third is refused.
        assert!(limiter.allow_message(ip).await);
        assert!(limiter.allow_message(ip).await);
        assert!(!limiter.allow_message(ip).await);
    }

    #[tokio::test]
    async fn test_limits_are_per_ip() {
        let limiter = RateLimiter::new(config(1, 1));
        let first = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));
        let second = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2));

        assert!(limiter.allow_connection(first).await);
        assert!(!limiter.allow_connection(first).await);
        assert!(limiter.allow_connection(second).await);

        assert!(limiter.allow_message(first).await);
        assert!(!limiter.allow_message(first).await);
        assert!(limiter.allow_message(second).await);
    }

    #[tokio::test]
    async fn test_counts_track_usage() {
        let limiter = RateLimiter::new(config(5, 5));
        let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7));

        assert_eq!(limiter.get_connection_count(ip).await, 0);
        assert_eq!(limiter.get_message_count(ip).await, 0);

        assert!(limiter.allow_connection(ip).await);
        assert!(limiter.allow_connection(ip).await);
        assert!(limiter.allow_message(ip).await);
        assert_eq!(limiter.get_connection_count(ip).await, 2);
        assert_eq!(limiter.get_message_count(ip).await, 1);

        limiter.release_connection(ip).await;
        limiter.release_connection(ip).await;
        assert_eq!(limiter.get_connection_count(ip).await, 0);

        // Releasing with nothing outstanding is a no-op.
        limiter.release_connection(ip).await;
        assert_eq!(limiter.get_connection_count(ip).await, 0);
    }

    #[tokio::test]
    async fn test_update_config_applies_new_limits() {
        let limiter = RateLimiter::new(config(1, 100));
        let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9));

        assert!(limiter.allow_connection(ip).await);
        assert!(!limiter.allow_connection(ip).await);

        // Raising the limit takes effect on the next check; the existing count is kept.
        limiter.update_config(config(2, 100)).await;
        assert!(limiter.allow_connection(ip).await);
        assert!(!limiter.allow_connection(ip).await);
    }
}