wae_resilience/
rate_limiter.rs1use async_trait::async_trait;
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use std::{
9 collections::VecDeque,
10 sync::Arc,
11 time::{Duration, Instant},
12};
13use wae_types::WaeError;
14
15#[async_trait]
17pub trait RateLimiter: Send + Sync {
18 async fn acquire(&self) -> Result<(), WaeError>;
20
21 fn try_acquire(&self) -> Result<(), WaeError>;
23
24 fn available_permits(&self) -> u64;
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TokenBucketConfig {
31 pub capacity: u64,
33 pub refill_rate: u64,
35}
36
37impl Default for TokenBucketConfig {
38 fn default() -> Self {
39 Self { capacity: 100, refill_rate: 10 }
40 }
41}
42
43impl TokenBucketConfig {
44 pub fn new(capacity: u64, refill_rate: u64) -> Self {
46 Self { capacity, refill_rate }
47 }
48}
49
50pub struct TokenBucket {
55 config: TokenBucketConfig,
57 tokens: Arc<Mutex<(u64, Instant)>>,
59}
60
61impl TokenBucket {
62 pub fn new(config: TokenBucketConfig) -> Self {
64 Self { tokens: Arc::new(Mutex::new((config.capacity, Instant::now()))), config }
65 }
66
67 pub fn with_defaults() -> Self {
69 Self::new(TokenBucketConfig::default())
70 }
71
72 fn refill(&self) {
74 let mut tokens = self.tokens.lock();
75 let now = Instant::now();
76 let elapsed = now.duration_since(tokens.1);
77 let tokens_to_add = (elapsed.as_secs_f64() * self.config.refill_rate as f64) as u64;
78
79 if tokens_to_add > 0 {
80 tokens.0 = (tokens.0 + tokens_to_add).min(self.config.capacity);
81 tokens.1 = now;
82 }
83 }
84}
85
86#[async_trait]
87impl RateLimiter for TokenBucket {
88 async fn acquire(&self) -> Result<(), WaeError> {
89 loop {
90 self.refill();
91
92 let wait_duration = {
93 let mut tokens = self.tokens.lock();
94 if tokens.0 > 0 {
95 tokens.0 -= 1;
96 return Ok(());
97 }
98
99 let tokens_needed = 1;
100 let wait_secs = tokens_needed as f64 / self.config.refill_rate as f64;
101 Duration::from_secs_f64(wait_secs)
102 };
103
104 tokio::time::sleep(wait_duration).await;
105 }
106 }
107
108 fn try_acquire(&self) -> Result<(), WaeError> {
109 self.refill();
110
111 let mut tokens = self.tokens.lock();
112 if tokens.0 > 0 {
113 tokens.0 -= 1;
114 Ok(())
115 }
116 else {
117 Err(WaeError::rate_limit_exceeded(self.config.capacity))
118 }
119 }
120
121 fn available_permits(&self) -> u64 {
122 self.refill();
123 let tokens = self.tokens.lock();
124 tokens.0
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct SlidingWindowConfig {
131 pub window_size: Duration,
133 pub max_requests: u64,
135}
136
137impl Default for SlidingWindowConfig {
138 fn default() -> Self {
139 Self { window_size: Duration::from_secs(1), max_requests: 100 }
140 }
141}
142
143impl SlidingWindowConfig {
144 pub fn new(window_size: Duration, max_requests: u64) -> Self {
146 Self { window_size, max_requests }
147 }
148}
149
150pub struct SlidingWindow {
155 config: SlidingWindowConfig,
157 timestamps: Arc<Mutex<VecDeque<Instant>>>,
159}
160
161impl SlidingWindow {
162 pub fn new(config: SlidingWindowConfig) -> Self {
164 Self { timestamps: Arc::new(Mutex::new(VecDeque::new())), config }
165 }
166
167 pub fn with_defaults() -> Self {
169 Self::new(SlidingWindowConfig::default())
170 }
171
172 fn cleanup(&self) {
174 let now = Instant::now();
175 let mut timestamps = self.timestamps.lock();
176
177 while let Some(&front) = timestamps.front() {
178 if now.duration_since(front) > self.config.window_size {
179 timestamps.pop_front();
180 }
181 else {
182 break;
183 }
184 }
185 }
186
187 fn current_count(&self) -> u64 {
189 self.cleanup();
190 let timestamps = self.timestamps.lock();
191 timestamps.len() as u64
192 }
193}
194
195#[async_trait]
196impl RateLimiter for SlidingWindow {
197 async fn acquire(&self) -> Result<(), WaeError> {
198 loop {
199 self.cleanup();
200
201 let wait_duration = {
202 let mut timestamps = self.timestamps.lock();
203 if (timestamps.len() as u64) < self.config.max_requests {
204 timestamps.push_back(Instant::now());
205 return Ok(());
206 }
207
208 if let Some(&oldest) = timestamps.front() {
209 let elapsed = Instant::now().duration_since(oldest);
210 if elapsed < self.config.window_size { self.config.window_size - elapsed } else { Duration::ZERO }
211 }
212 else {
213 Duration::ZERO
214 }
215 };
216
217 if wait_duration.is_zero() {
218 continue;
219 }
220
221 tokio::time::sleep(wait_duration).await;
222 }
223 }
224
225 fn try_acquire(&self) -> Result<(), WaeError> {
226 self.cleanup();
227
228 let mut timestamps = self.timestamps.lock();
229 if (timestamps.len() as u64) < self.config.max_requests {
230 timestamps.push_back(Instant::now());
231 Ok(())
232 }
233 else {
234 Err(WaeError::rate_limit_exceeded(self.config.max_requests))
235 }
236 }
237
238 fn available_permits(&self) -> u64 {
239 let current = self.current_count();
240 self.config.max_requests.saturating_sub(current)
241 }
242}