// sentinel_proxy/distributed_rate_limit.rs

use std::sync::atomic::{AtomicU64, Ordering};

// Imports required by the feature-gated code below. `parking_lot::RwLock` is an
// assumption: the lock guards are used without `.unwrap()`, which matches the
// parking_lot API rather than `std::sync::RwLock`.
#[cfg(feature = "distributed-rate-limit")]
use std::sync::atomic::AtomicBool;
#[cfg(feature = "distributed-rate-limit")]
use std::sync::Arc;
#[cfg(feature = "distributed-rate-limit")]
use std::time::Duration;

#[cfg(feature = "distributed-rate-limit")]
use parking_lot::RwLock;

#[cfg(feature = "distributed-rate-limit")]
use tracing::{debug, error, trace};
use tracing::warn;

#[cfg(feature = "distributed-rate-limit")]
use redis::aio::ConnectionManager;

use sentinel_config::RedisBackendConfig;

use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};

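/// Counters describing distributed rate-limit activity. All fields are plain
/// atomics so they can be updated from concurrent request handlers without a lock.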
#[derive(Debug, Default)]
pub struct DistributedRateLimitStats {
    pub total_checks: AtomicU64,
    pub allowed: AtomicU64,
    pub limited: AtomicU64,
    pub redis_errors: AtomicU64,
    pub local_fallbacks: AtomicU64,
}

impl DistributedRateLimitStats {
    pub fn record_check(&self, outcome: RateLimitOutcome) {
        self.total_checks.fetch_add(1, Ordering::Relaxed);
        match outcome {
            RateLimitOutcome::Allowed => {
                self.allowed.fetch_add(1, Ordering::Relaxed);
            }
            RateLimitOutcome::Limited => {
                self.limited.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    pub fn record_redis_error(&self) {
        self.redis_errors.fetch_add(1, Ordering::Relaxed);
    }

    pub fn record_local_fallback(&self) {
        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
    }
}

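/// Sliding-window rate limiter backed by a shared Redis instance, so every
/// proxy instance enforces one global limit rather than a per-process one.
/// Redis errors are returned to the caller, which can mark the limiter
/// unhealthy and, when `fallback_local` is set, drop back to local limiting.
///
/// A minimal call-site sketch (the handler names are hypothetical, not part of
/// this crate):
///
/// ```ignore
/// match limiter.check(client_key).await {
///     Ok((RateLimitOutcome::Limited, _count)) => reject_with_429(),
///     Ok((RateLimitOutcome::Allowed, _count)) => forward_request(),
///     Err(_) => {
///         limiter.mark_unhealthy();
///         // Optionally check limiter.fallback_enabled() and consult the
///         // local limiter instead of failing the request.
///     }
/// }
/// ```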
#[cfg(feature = "distributed-rate-limit")]
pub struct RedisRateLimiter {
    connection: ConnectionManager,
    config: RwLock<RedisConfig>,
    healthy: AtomicBool,
    pub stats: Arc<DistributedRateLimitStats>,
}

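/// Snapshot of the tunable settings, kept behind an `RwLock` so that
/// `update_config` can swap in new values without rebuilding the Redis
/// connection.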
#[cfg(feature = "distributed-rate-limit")]
#[derive(Debug, Clone)]
struct RedisConfig {
    key_prefix: String,
    max_rps: u32,
    window_secs: u64,
    timeout: Duration,
    fallback_local: bool,
}

#[cfg(feature = "distributed-rate-limit")]
impl RedisRateLimiter {
    pub async fn new(
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(backend_config.url.as_str())?;
        let connection = ConnectionManager::new(client).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Redis rate limiter initialized"
        );

        Ok(Self {
            connection,
            config: RwLock::new(RedisConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(DistributedRateLimitStats::default()),
        })
    }

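    /// Checks `key` against the shared limit using a sorted-set sliding window:
    /// entries older than the window are pruned, the current request is added
    /// with its millisecond timestamp as both score and member, and the
    /// surviving entries are counted (so requests arriving in the same
    /// millisecond collapse into one entry). The whole pipeline is bounded by
    /// the configured timeout so a slow Redis cannot stall the request path.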
    pub async fn check(&self, key: &str) -> Result<(RateLimitOutcome, i64), redis::RedisError> {
        let config = self.config.read().clone();
        let full_key = format!("{}{}", config.key_prefix, key);

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as f64;

        let window_start = now - (config.window_secs as f64 * 1000.0);

        let mut conn = self.connection.clone();

        // MULTI/EXEC pipeline: prune expired entries, record this request,
        // refresh the key's TTL, then count what remains in the window.
        let result: Result<(i64,), _> = tokio::time::timeout(config.timeout, async {
            redis::pipe()
                .atomic()
                .zrembyscore(&full_key, 0.0, window_start)
                .ignore()
                .zadd(&full_key, now, now.to_string())
                .ignore()
                .expire(&full_key, (config.window_secs * 2) as i64)
                .ignore()
                .zcount(&full_key, window_start, now)
                .query_async(&mut conn)
                .await
        })
        .await
        .map_err(|_| {
            // Surface a timeout as an ordinary Redis I/O error.
            redis::RedisError::from((redis::ErrorKind::IoError, "Redis operation timed out"))
        })?;

        let (count,) = result?;

        self.healthy.store(true, Ordering::Relaxed);

        let outcome = if count > config.max_rps as i64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Redis rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

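    /// Applies new settings without rebuilding the Redis connection. The window
    /// length stays at its constructor default of one second; only the key
    /// prefix, limit, timeout, and fallback flag are updated.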
    pub fn update_config(
        &self,
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
    }

    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_redis_error();
    }

    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}

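/// Stub kept so call sites still compile when the `distributed-rate-limit`
/// feature is disabled; constructing it always fails.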
#[cfg(not(feature = "distributed-rate-limit"))]
pub struct RedisRateLimiter;

#[cfg(not(feature = "distributed-rate-limit"))]
impl RedisRateLimiter {
    pub async fn new(
        _backend_config: &RedisBackendConfig,
        _rate_config: &RateLimitConfig,
    ) -> Result<Self, String> {
        Err("Distributed rate limiting requires the 'distributed-rate-limit' feature".to_string())
    }
}

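/// Startup helper: builds the Redis-backed limiter, logging and returning
/// `None` on failure so the caller can stay on local rate limiting.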
#[cfg(feature = "distributed-rate-limit")]
pub async fn create_redis_rate_limiter(
    backend_config: &RedisBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    match RedisRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => {
            debug!(
                url = %backend_config.url,
                "Redis rate limiter created successfully"
            );
            Some(limiter)
        }
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Redis rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
            None
        }
    }
}

#[cfg(not(feature = "distributed-rate-limit"))]
pub async fn create_redis_rate_limiter(
    _backend_config: &RedisBackendConfig,
    _rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    warn!(
        "Distributed rate limiting requested but feature is disabled. Using local rate limiting."
    );
    None
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stats_recording() {
        let stats = DistributedRateLimitStats::default();

        stats.record_check(RateLimitOutcome::Allowed);
        stats.record_check(RateLimitOutcome::Allowed);
        stats.record_check(RateLimitOutcome::Limited);

        assert_eq!(stats.total_checks.load(Ordering::Relaxed), 3);
        assert_eq!(stats.allowed.load(Ordering::Relaxed), 2);
        assert_eq!(stats.limited.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_stats_redis_errors() {
        let stats = DistributedRateLimitStats::default();

        stats.record_redis_error();
        stats.record_redis_error();
        stats.record_local_fallback();

        assert_eq!(stats.redis_errors.load(Ordering::Relaxed), 2);
        assert_eq!(stats.local_fallbacks.load(Ordering::Relaxed), 1);
    }
}