1use crate::error::{LinkError, Result};
2use std::time::{Duration, Instant};
3use std::collections::VecDeque;
4
/// Limits enforced by [`RateLimiter`].
///
/// Start from [`RateLimitConfig::new`] (or `Default`) and adjust individual
/// limits with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of messages accepted per `window_duration`.
    /// (The name says "per second" because the default window is one second.)
    pub max_messages_per_second: u32,
    /// Maximum number of bytes accepted per `window_duration`.
    pub max_bytes_per_second: u64,
    /// Maximum number of messages accepted within the limiter's short
    /// (100 ms) burst interval.
    pub burst_size: u32,
    /// Length of the sliding window used by the message and byte limits.
    pub window_duration: Duration,
}

impl Default for RateLimitConfig {
    /// 1000 messages and 10 MiB per one-second window, bursts of up to 100.
    fn default() -> Self {
        Self {
            max_messages_per_second: 1000,
            max_bytes_per_second: 10 * 1024 * 1024,
            burst_size: 100,
            window_duration: Duration::from_secs(1),
        }
    }
}

impl RateLimitConfig {
    /// Returns the default configuration; equivalent to `RateLimitConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the per-window message limit.
    #[must_use]
    pub fn with_max_messages(mut self, max: u32) -> Self {
        self.max_messages_per_second = max;
        self
    }

    /// Sets the per-window byte limit.
    #[must_use]
    pub fn with_max_bytes(mut self, max: u64) -> Self {
        self.max_bytes_per_second = max;
        self
    }

    /// Sets the maximum number of messages allowed in one burst interval.
    #[must_use]
    pub fn with_burst_size(mut self, size: u32) -> Self {
        self.burst_size = size;
        self
    }

    /// Sets the length of the sliding window.
    #[must_use]
    pub fn with_window_duration(mut self, duration: Duration) -> Self {
        self.window_duration = duration;
        self
    }
}
49
50struct MessageRecord {
51 timestamp: Instant,
52 size: u64,
53}
54
55pub struct RateLimiter {
56 config: RateLimitConfig,
57 message_history: VecDeque<MessageRecord>,
58 byte_history: VecDeque<MessageRecord>,
59 total_messages: u64,
60 total_bytes: u64,
61 total_rejected: u64,
62}
63
64impl RateLimiter {
65 pub fn new(config: RateLimitConfig) -> Self {
66 Self {
67 config,
68 message_history: VecDeque::new(),
69 byte_history: VecDeque::new(),
70 total_messages: 0,
71 total_bytes: 0,
72 total_rejected: 0,
73 }
74 }
75
76 pub fn check_and_record(&mut self, message_size: u64) -> Result<()> {
77 let now = Instant::now();
78
79 self.cleanup_old_records(now);
80
81 let messages_in_window = self.count_messages_in_window(now);
82 let bytes_in_window = self.count_bytes_in_window(now);
83
84 if messages_in_window >= self.config.max_messages_per_second {
85 self.total_rejected += 1;
86 return Err(LinkError::RateLimitExceeded(
87 format!("Message rate limit exceeded: {} msgs/sec", self.config.max_messages_per_second)
88 ));
89 }
90
91 if bytes_in_window + message_size > self.config.max_bytes_per_second {
92 self.total_rejected += 1;
93 return Err(LinkError::RateLimitExceeded(
94 format!("Byte rate limit exceeded: {} bytes/sec", self.config.max_bytes_per_second)
95 ));
96 }
97
98 let burst_count = self.count_recent_burst(now);
99 if burst_count >= self.config.burst_size {
100 self.total_rejected += 1;
101 return Err(LinkError::RateLimitExceeded(
102 format!("Burst limit exceeded: {} msgs", self.config.burst_size)
103 ));
104 }
105
106 self.record_message(now, message_size);
107
108 Ok(())
109 }
110
111 pub fn check(&mut self, message_size: u64) -> bool {
112 self.check_and_record(message_size).is_ok()
113 }
114
115 fn record_message(&mut self, timestamp: Instant, size: u64) {
116 let record = MessageRecord {
117 timestamp,
118 size,
119 };
120
121 self.message_history.push_back(record.clone());
122 self.byte_history.push_back(record);
123
124 self.total_messages += 1;
125 self.total_bytes += size;
126 }
127
128 fn cleanup_old_records(&mut self, now: Instant) {
129 let cutoff = now - self.config.window_duration;
130
131 while let Some(record) = self.message_history.front() {
132 if record.timestamp < cutoff {
133 self.message_history.pop_front();
134 } else {
135 break;
136 }
137 }
138
139 while let Some(record) = self.byte_history.front() {
140 if record.timestamp < cutoff {
141 self.byte_history.pop_front();
142 } else {
143 break;
144 }
145 }
146 }
147
148 fn count_messages_in_window(&self, now: Instant) -> u32 {
149 let cutoff = now - self.config.window_duration;
150 self.message_history.iter()
151 .filter(|r| r.timestamp >= cutoff)
152 .count() as u32
153 }
154
155 fn count_bytes_in_window(&self, now: Instant) -> u64 {
156 let cutoff = now - self.config.window_duration;
157 self.byte_history.iter()
158 .filter(|r| r.timestamp >= cutoff)
159 .map(|r| r.size)
160 .sum()
161 }
162
163 fn count_recent_burst(&self, now: Instant) -> u32 {
164 let burst_window = Duration::from_millis(100);
165 let cutoff = now - burst_window;
166
167 self.message_history.iter()
168 .filter(|r| r.timestamp >= cutoff)
169 .count() as u32
170 }
171
172 pub fn reset(&mut self) {
173 self.message_history.clear();
174 self.byte_history.clear();
175 }
176
177 pub fn get_stats(&self) -> RateLimitStats {
178 RateLimitStats {
179 total_messages: self.total_messages,
180 total_bytes: self.total_bytes,
181 total_rejected: self.total_rejected,
182 messages_in_window: self.message_history.len() as u32,
183 bytes_in_window: self.byte_history.iter().map(|r| r.size).sum(),
184 }
185 }
186
187 pub fn get_config(&self) -> &RateLimitConfig {
188 &self.config
189 }
190
191 pub fn set_config(&mut self, config: RateLimitConfig) {
192 self.config = config;
193 }
194}
195
196impl Clone for MessageRecord {
197 fn clone(&self) -> Self {
198 Self {
199 timestamp: self.timestamp,
200 size: self.size,
201 }
202 }
203}
204
/// Snapshot of a [`RateLimiter`]'s counters, as returned by `RateLimiter::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Messages accepted since the limiter was created (`reset` does not clear it).
    pub total_messages: u64,
    /// Bytes accepted since the limiter was created (`reset` does not clear it).
    pub total_bytes: u64,
    /// Messages rejected since the limiter was created (`reset` does not clear it).
    pub total_rejected: u64,
    /// Accepted messages the limiter is tracking in its sliding window.
    pub messages_in_window: u32,
    /// Total size in bytes of the messages counted in `messages_in_window`.
    pub bytes_in_window: u64,
}
213
214pub struct TokenBucketRateLimiter {
215 capacity: u32,
216 tokens: u32,
217 refill_rate: u32,
218 last_refill: Instant,
219 total_messages: u64,
220 total_rejected: u64,
221}
222
223impl TokenBucketRateLimiter {
224 pub fn new(capacity: u32, refill_rate: u32) -> Self {
225 Self {
226 capacity,
227 tokens: capacity,
228 refill_rate,
229 last_refill: Instant::now(),
230 total_messages: 0,
231 total_rejected: 0,
232 }
233 }
234
235 pub fn check_and_consume(&mut self) -> Result<()> {
236 self.refill();
237
238 if self.tokens == 0 {
239 self.total_rejected += 1;
240 return Err(LinkError::RateLimitExceeded(
241 format!("Token bucket empty (capacity: {})", self.capacity)
242 ));
243 }
244
245 self.tokens -= 1;
246 self.total_messages += 1;
247
248 Ok(())
249 }
250
251 pub fn check(&mut self) -> bool {
252 self.check_and_consume().is_ok()
253 }
254
255 fn refill(&mut self) {
256 let now = Instant::now();
257 let elapsed = now.duration_since(self.last_refill);
258 let elapsed_secs = elapsed.as_secs_f64();
259
260 let tokens_to_add = (elapsed_secs * self.refill_rate as f64) as u32;
261
262 if tokens_to_add > 0 {
263 self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
264 self.last_refill = now;
265 }
266 }
267
268 pub fn reset(&mut self) {
269 self.tokens = self.capacity;
270 self.last_refill = Instant::now();
271 }
272
273 pub fn get_available_tokens(&self) -> u32 {
274 self.tokens
275 }
276
277 pub fn get_stats(&self) -> (u64, u64) {
278 (self.total_messages, self.total_rejected)
279 }
280}
281
#[cfg(test)]
mod tests {
    // Unit tests for the sliding-window `RateLimiter` and the
    // `TokenBucketRateLimiter`. Several depend on wall-clock timing:
    // the burst test assumes its six calls land inside the limiter's 100 ms
    // burst interval, and the window and token-bucket tests use real sleeps.
    use super::*;
    use std::thread;

    #[test]
    fn test_rate_limiter_basic() {
        // Ten 50-byte messages total 500 bytes, under the 1000-byte cap, so the
        // 11th message is rejected by the 10-message count limit.
        let config = RateLimitConfig::new()
            .with_max_messages(10)
            .with_max_bytes(1000);

        let mut limiter = RateLimiter::new(config);

        for _ in 0..10 {
            assert!(limiter.check_and_record(50).is_ok());
        }

        assert!(limiter.check_and_record(50).is_err());
    }

    #[test]
    fn test_rate_limiter_byte_limit() {
        // The byte check includes the incoming message: 300 + 300 > 500.
        let config = RateLimitConfig::new()
            .with_max_messages(100)
            .with_max_bytes(500);

        let mut limiter = RateLimiter::new(config);

        assert!(limiter.check_and_record(300).is_ok());
        assert!(limiter.check_and_record(300).is_err());
    }

    #[test]
    fn test_rate_limiter_burst() {
        // Neither the 1000-message cap nor the default byte cap is reached; the
        // 6th message is rejected because 5 were already accepted within the
        // 100 ms burst interval (assumes these calls run within 100 ms).
        let config = RateLimitConfig::new()
            .with_max_messages(1000)
            .with_burst_size(5);

        let mut limiter = RateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_and_record(100).is_ok());
        }

        assert!(limiter.check_and_record(100).is_err());
    }

    #[test]
    fn test_rate_limiter_window() {
        // Fill a 100 ms window with 5 messages, then sleep past it so the
        // earlier records expire and a new message is admitted.
        let config = RateLimitConfig::new()
            .with_max_messages(5)
            .with_window_duration(Duration::from_millis(100));

        let mut limiter = RateLimiter::new(config);

        for _ in 0..5 {
            assert!(limiter.check_and_record(100).is_ok());
        }

        assert!(limiter.check_and_record(100).is_err());

        thread::sleep(Duration::from_millis(150));

        assert!(limiter.check_and_record(100).is_ok());
    }

    #[test]
    fn test_token_bucket() {
        // Capacity 5, refilling 10 tokens/s: drain the bucket, then wait 100 ms
        // so at least one whole token has been earned.
        let mut limiter = TokenBucketRateLimiter::new(5, 10);

        for _ in 0..5 {
            assert!(limiter.check_and_consume().is_ok());
        }

        assert!(limiter.check_and_consume().is_err());

        thread::sleep(Duration::from_millis(100));
        // Explicit refill (private, reachable from this child module); the next
        // `check_and_consume` would also refill on its own.
        limiter.refill();

        assert!(limiter.check_and_consume().is_ok());
    }

    #[test]
    fn test_rate_limiter_stats() {
        // Six attempts against a 5-message cap: 5 accepted, 1 rejected.
        let config = RateLimitConfig::new().with_max_messages(5);
        let mut limiter = RateLimiter::new(config);

        for _ in 0..3 {
            let _ = limiter.check_and_record(100);
        }

        for _ in 0..3 {
            let _ = limiter.check_and_record(100);
        }

        let stats = limiter.get_stats();
        assert_eq!(stats.total_messages, 5);
        assert_eq!(stats.total_rejected, 1);
    }
}