Skip to main content

wae_resilience/
rate_limiter.rs

//! 限流器模块 (rate limiters)
//!
//! 提供令牌桶和滑动窗口两种限流算法实现。
//! Provides token-bucket and sliding-window rate-limiting implementations.

5use async_trait::async_trait;
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use std::{
9    collections::VecDeque,
10    sync::Arc,
11    time::{Duration, Instant},
12};
13use wae_types::WaeError;
14
15/// 限流器 trait
16#[async_trait]
17pub trait RateLimiter: Send + Sync {
18    /// 获取许可 (阻塞直到获取成功)
19    async fn acquire(&self) -> Result<(), WaeError>;
20
21    /// 尝试获取许可 (非阻塞)
22    fn try_acquire(&self) -> Result<(), WaeError>;
23
24    /// 获取当前可用许可数
25    fn available_permits(&self) -> u64;
26}
27
28/// 令牌桶配置
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TokenBucketConfig {
31    /// 桶容量
32    pub capacity: u64,
33    /// 每秒补充的令牌数
34    pub refill_rate: u64,
35}
36
37impl Default for TokenBucketConfig {
38    fn default() -> Self {
39        Self { capacity: 100, refill_rate: 10 }
40    }
41}
42
43impl TokenBucketConfig {
44    /// 创建新的令牌桶配置
45    pub fn new(capacity: u64, refill_rate: u64) -> Self {
46        Self { capacity, refill_rate }
47    }
48}
49
50/// 令牌桶限流器
51///
52/// 实现令牌桶算法,支持突发流量。
53/// 以固定速率向桶中添加令牌,请求消耗令牌,桶空时拒绝请求。
54pub struct TokenBucket {
55    /// 配置
56    config: TokenBucketConfig,
57    /// 当前令牌数
58    tokens: Arc<Mutex<(u64, Instant)>>,
59}
60
61impl TokenBucket {
62    /// 创建新的令牌桶限流器
63    pub fn new(config: TokenBucketConfig) -> Self {
64        Self { tokens: Arc::new(Mutex::new((config.capacity, Instant::now()))), config }
65    }
66
67    /// 使用默认配置创建
68    pub fn with_defaults() -> Self {
69        Self::new(TokenBucketConfig::default())
70    }
71
72    /// 补充令牌
73    fn refill(&self) {
74        let mut tokens = self.tokens.lock();
75        let now = Instant::now();
76        let elapsed = now.duration_since(tokens.1);
77        let tokens_to_add = (elapsed.as_secs_f64() * self.config.refill_rate as f64) as u64;
78
79        if tokens_to_add > 0 {
80            tokens.0 = (tokens.0 + tokens_to_add).min(self.config.capacity);
81            tokens.1 = now;
82        }
83    }
84}
85
86#[async_trait]
87impl RateLimiter for TokenBucket {
88    async fn acquire(&self) -> Result<(), WaeError> {
89        loop {
90            self.refill();
91
92            let wait_duration = {
93                let mut tokens = self.tokens.lock();
94                if tokens.0 > 0 {
95                    tokens.0 -= 1;
96                    return Ok(());
97                }
98
99                let tokens_needed = 1;
100                let wait_secs = tokens_needed as f64 / self.config.refill_rate as f64;
101                Duration::from_secs_f64(wait_secs)
102            };
103
104            tokio::time::sleep(wait_duration).await;
105        }
106    }
107
108    fn try_acquire(&self) -> Result<(), WaeError> {
109        self.refill();
110
111        let mut tokens = self.tokens.lock();
112        if tokens.0 > 0 {
113            tokens.0 -= 1;
114            Ok(())
115        }
116        else {
117            Err(WaeError::rate_limit_exceeded(self.config.capacity))
118        }
119    }
120
121    fn available_permits(&self) -> u64 {
122        self.refill();
123        let tokens = self.tokens.lock();
124        tokens.0
125    }
126}
127
128/// 滑动窗口配置
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct SlidingWindowConfig {
131    /// 时间窗口大小
132    pub window_size: Duration,
133    /// 窗口内最大请求数
134    pub max_requests: u64,
135}
136
137impl Default for SlidingWindowConfig {
138    fn default() -> Self {
139        Self { window_size: Duration::from_secs(1), max_requests: 100 }
140    }
141}
142
143impl SlidingWindowConfig {
144    /// 创建新的滑动窗口配置
145    pub fn new(window_size: Duration, max_requests: u64) -> Self {
146        Self { window_size, max_requests }
147    }
148}
149
150/// 滑动窗口限流器
151///
152/// 实现滑动窗口算法,精确控制请求速率。
153/// 记录时间窗口内的请求时间戳,超过限制时拒绝请求。
154pub struct SlidingWindow {
155    /// 配置
156    config: SlidingWindowConfig,
157    /// 请求时间戳队列
158    timestamps: Arc<Mutex<VecDeque<Instant>>>,
159}
160
161impl SlidingWindow {
162    /// 创建新的滑动窗口限流器
163    pub fn new(config: SlidingWindowConfig) -> Self {
164        Self { timestamps: Arc::new(Mutex::new(VecDeque::new())), config }
165    }
166
167    /// 使用默认配置创建
168    pub fn with_defaults() -> Self {
169        Self::new(SlidingWindowConfig::default())
170    }
171
172    /// 清理过期的时间戳
173    fn cleanup(&self) {
174        let now = Instant::now();
175        let mut timestamps = self.timestamps.lock();
176
177        while let Some(&front) = timestamps.front() {
178            if now.duration_since(front) > self.config.window_size {
179                timestamps.pop_front();
180            }
181            else {
182                break;
183            }
184        }
185    }
186
187    /// 获取当前窗口内的请求数
188    fn current_count(&self) -> u64 {
189        self.cleanup();
190        let timestamps = self.timestamps.lock();
191        timestamps.len() as u64
192    }
193}
194
195#[async_trait]
196impl RateLimiter for SlidingWindow {
197    async fn acquire(&self) -> Result<(), WaeError> {
198        loop {
199            self.cleanup();
200
201            let wait_duration = {
202                let mut timestamps = self.timestamps.lock();
203                if (timestamps.len() as u64) < self.config.max_requests {
204                    timestamps.push_back(Instant::now());
205                    return Ok(());
206                }
207
208                if let Some(&oldest) = timestamps.front() {
209                    let elapsed = Instant::now().duration_since(oldest);
210                    if elapsed < self.config.window_size { self.config.window_size - elapsed } else { Duration::ZERO }
211                }
212                else {
213                    Duration::ZERO
214                }
215            };
216
217            if wait_duration.is_zero() {
218                continue;
219            }
220
221            tokio::time::sleep(wait_duration).await;
222        }
223    }
224
225    fn try_acquire(&self) -> Result<(), WaeError> {
226        self.cleanup();
227
228        let mut timestamps = self.timestamps.lock();
229        if (timestamps.len() as u64) < self.config.max_requests {
230            timestamps.push_back(Instant::now());
231            Ok(())
232        }
233        else {
234            Err(WaeError::rate_limit_exceeded(self.config.max_requests))
235        }
236    }
237
238    fn available_permits(&self) -> u64 {
239        let current = self.current_count();
240        self.config.max_requests.saturating_sub(current)
241    }
242}