// wae_authentication/rate_limit/service.rs
use chrono::Utc;
4use std::{collections::HashMap, sync::Arc};
5use tokio::sync::RwLock;
6
7use crate::rate_limit::{RateLimitConfig, RateLimitError, RateLimitKey, RateLimitResult};
8
/// Per-key request history kept by the rate limiter.
#[derive(Clone, Debug)]
struct RequestRecord {
    /// Timestamps of the requests recorded for one key, in insertion order.
    timestamps: Vec<i64>,
}
15
16#[derive(Debug, Clone)]
18pub struct RateLimitService {
19 config: RateLimitConfig,
20 requests: Arc<RwLock<HashMap<RateLimitKey, RequestRecord>>>,
21}
22
23impl RateLimitService {
24 pub fn new(config: RateLimitConfig) -> Self {
29 Self { config, requests: Arc::new(RwLock::new(HashMap::new())) }
30 }
31
32 pub fn default() -> Self {
34 Self::new(RateLimitConfig::default())
35 }
36
37 pub async fn check_and_record(&self, key: RateLimitKey) -> RateLimitResult<()> {
42 let now = Utc::now().timestamp();
43 let window_start = now - self.config.window_seconds as i64;
44
45 let mut requests = self.requests.write().await;
46
47 let record = requests.entry(key).or_insert_with(|| RequestRecord { timestamps: Vec::new() });
48
49 record.timestamps.retain(|&ts| ts > window_start);
50
51 if record.timestamps.len() >= self.config.max_requests as usize {
52 let oldest_ts = record.timestamps.first().copied().unwrap_or(now);
53 let retry_after = (oldest_ts + self.config.window_seconds as i64 - now).max(0) as u64;
54 return Err(RateLimitError::RateLimitExceeded { retry_after });
55 }
56
57 record.timestamps.push(now);
58
59 self.cleanup_old_records().await;
60
61 Ok(())
62 }
63
64 pub async fn reset(&self, key: &RateLimitKey) {
69 self.requests.write().await.remove(key);
70 }
71
72 pub async fn get_count(&self, key: &RateLimitKey) -> u32 {
77 let now = Utc::now().timestamp();
78 let window_start = now - self.config.window_seconds as i64;
79
80 let requests = self.requests.read().await;
81
82 requests.get(key).map(|record| record.timestamps.iter().filter(|&&ts| ts > window_start).count() as u32).unwrap_or(0)
83 }
84
85 async fn cleanup_old_records(&self) {
87 let now = Utc::now().timestamp();
88 let window_start = now - self.config.window_seconds as i64;
89
90 let mut requests = self.requests.write().await;
91 requests.retain(|_, record| {
92 record.timestamps.retain(|&ts| ts > window_start);
93 !record.timestamps.is_empty()
94 });
95 }
96
97 pub fn config(&self) -> &RateLimitConfig {
99 &self.config
100 }
101}
102
103impl Default for RateLimitService {
104 fn default() -> Self {
105 Self::new(RateLimitConfig::default())
106 }
107}