Skip to main content

wae_authentication/rate_limit/
service.rs

1//! 速率限制服务实现
2
3use chrono::Utc;
4use std::{collections::HashMap, sync::Arc};
5use tokio::sync::RwLock;
6
7use crate::rate_limit::{RateLimitConfig, RateLimitError, RateLimitKey, RateLimitResult};
8
/// Request history tracked for a single rate-limit key.
#[derive(Clone, Debug)]
struct RequestRecord {
    /// Unix timestamps (whole seconds, from `Utc::now().timestamp()`) of the
    /// requests recorded for this key, in the order they were pushed.
    timestamps: Vec<i64>,
}
15
16/// 速率限制服务
17#[derive(Debug, Clone)]
18pub struct RateLimitService {
19    config: RateLimitConfig,
20    requests: Arc<RwLock<HashMap<RateLimitKey, RequestRecord>>>,
21}
22
23impl RateLimitService {
24    /// 创建新的速率限制服务
25    ///
26    /// # Arguments
27    /// * `config` - 速率限制配置
28    pub fn new(config: RateLimitConfig) -> Self {
29        Self { config, requests: Arc::new(RwLock::new(HashMap::new())) }
30    }
31
32    /// 使用默认配置创建速率限制服务
33    pub fn default() -> Self {
34        Self::new(RateLimitConfig::default())
35    }
36
37    /// 检查并记录请求
38    ///
39    /// # Arguments
40    /// * `key` - 速率限制键
41    pub async fn check_and_record(&self, key: RateLimitKey) -> RateLimitResult<()> {
42        let now = Utc::now().timestamp();
43        let window_start = now - self.config.window_seconds as i64;
44
45        let mut requests = self.requests.write().await;
46
47        let record = requests.entry(key).or_insert_with(|| RequestRecord { timestamps: Vec::new() });
48
49        record.timestamps.retain(|&ts| ts > window_start);
50
51        if record.timestamps.len() >= self.config.max_requests as usize {
52            let oldest_ts = record.timestamps.first().copied().unwrap_or(now);
53            let retry_after = (oldest_ts + self.config.window_seconds as i64 - now).max(0) as u64;
54            return Err(RateLimitError::RateLimitExceeded { retry_after });
55        }
56
57        record.timestamps.push(now);
58
59        self.cleanup_old_records().await;
60
61        Ok(())
62    }
63
64    /// 重置指定键的速率限制
65    ///
66    /// # Arguments
67    /// * `key` - 速率限制键
68    pub async fn reset(&self, key: &RateLimitKey) {
69        self.requests.write().await.remove(key);
70    }
71
72    /// 获取当前请求计数
73    ///
74    /// # Arguments
75    /// * `key` - 速率限制键
76    pub async fn get_count(&self, key: &RateLimitKey) -> u32 {
77        let now = Utc::now().timestamp();
78        let window_start = now - self.config.window_seconds as i64;
79
80        let requests = self.requests.read().await;
81
82        requests.get(key).map(|record| record.timestamps.iter().filter(|&&ts| ts > window_start).count() as u32).unwrap_or(0)
83    }
84
85    /// 清理过期记录
86    async fn cleanup_old_records(&self) {
87        let now = Utc::now().timestamp();
88        let window_start = now - self.config.window_seconds as i64;
89
90        let mut requests = self.requests.write().await;
91        requests.retain(|_, record| {
92            record.timestamps.retain(|&ts| ts > window_start);
93            !record.timestamps.is_empty()
94        });
95    }
96
97    /// 获取配置
98    pub fn config(&self) -> &RateLimitConfig {
99        &self.config
100    }
101}
102
103impl Default for RateLimitService {
104    fn default() -> Self {
105        Self::new(RateLimitConfig::default())
106    }
107}