1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20
21use parking_lot::RwLock;
22use tracing::{debug, error, trace, warn};
23
24#[cfg(feature = "distributed-rate-limit")]
25use redis::aio::ConnectionManager;
26
27use sentinel_config::RedisBackendConfig;
28
29use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
30
/// Lock-free counters describing how the distributed rate limiter has behaved.
///
/// Every counter starts at zero (via `Default`) and is only ever incremented.
#[derive(Default, Debug)]
pub struct DistributedRateLimitStats {
    /// Number of rate-limit decisions recorded.
    pub total_checks: AtomicU64,
    /// Decisions whose outcome was `Allowed`.
    pub allowed: AtomicU64,
    /// Decisions whose outcome was `Limited`.
    pub limited: AtomicU64,
    /// Recorded Redis failures.
    pub redis_errors: AtomicU64,
    /// Recorded fallbacks to local limiting.
    pub local_fallbacks: AtomicU64,
}
45
46impl DistributedRateLimitStats {
47 pub fn record_check(&self, outcome: RateLimitOutcome) {
48 self.total_checks.fetch_add(1, Ordering::Relaxed);
49 match outcome {
50 RateLimitOutcome::Allowed => {
51 self.allowed.fetch_add(1, Ordering::Relaxed);
52 }
53 RateLimitOutcome::Limited => {
54 self.limited.fetch_add(1, Ordering::Relaxed);
55 }
56 }
57 }
58
59 pub fn record_redis_error(&self) {
60 self.redis_errors.fetch_add(1, Ordering::Relaxed);
61 }
62
63 pub fn record_local_fallback(&self) {
64 self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
65 }
66}
67
/// Sliding-window rate limiter whose state lives in Redis, so several proxy instances
/// can share one limit.
#[cfg(feature = "distributed-rate-limit")]
pub struct RedisRateLimiter {
    /// Redis connection handle; cloned for each check.
    connection: ConnectionManager,
    /// Runtime-reloadable settings (prefix, limit, window, timeout, fallback flag).
    config: RwLock<RedisConfig>,
    /// Set after a successful Redis round trip and cleared by `mark_unhealthy`.
    healthy: AtomicBool,
    /// Shared counters describing limiter activity.
    pub stats: Arc<DistributedRateLimitStats>,
}
80
/// Snapshot-able settings used by [`RedisRateLimiter`] for each check.
#[cfg(feature = "distributed-rate-limit")]
#[derive(Clone, Debug)]
struct RedisConfig {
    /// Prepended to every caller-supplied key to form the Redis key.
    key_prefix: String,
    /// Maximum number of requests allowed inside one window.
    max_rps: u32,
    /// Length of the sliding window, in seconds.
    window_secs: u64,
    /// Upper bound on a single Redis round trip.
    timeout: Duration,
    /// Whether callers should fall back to local limiting when Redis is unavailable.
    fallback_local: bool,
}
90
/// Builds a sorted-set member that is unique for every recorded request.
///
/// The ZSET score already carries the request timestamp; the member only has to be
/// distinct so that `ZADD` inserts a new entry. Using the bare timestamp as the member
/// would make requests landing in the same millisecond — from this process or from
/// another instance sharing the key — overwrite each other's entry. The window would
/// then undercount traffic, and bursts could exceed `max_rps`.
///
/// The member has the form `"{now_ms}-{sequence}-{nonce}"`. The per-process sequence
/// keeps members unique within this process, and the random nonce makes collisions
/// between processes vanishingly unlikely.
#[cfg_attr(not(feature = "distributed-rate-limit"), allow(dead_code))]
fn unique_window_member(now_ms: f64) -> String {
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);
    let sequence = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    // std-only randomness: each `RandomState` is randomly keyed, so finishing an
    // empty hasher built from it yields an unpredictable 64-bit value.
    let state = std::collections::hash_map::RandomState::new();
    let nonce = std::hash::Hasher::finish(&std::hash::BuildHasher::build_hasher(&state));

    format!("{}-{}-{:016x}", now_ms, sequence, nonce)
}

#[cfg(feature = "distributed-rate-limit")]
impl RedisRateLimiter {
    /// Connects to Redis and builds a limiter from the backend and rate-limit configs.
    ///
    /// The sliding window is fixed at one second, so `rate_config.max_rps` is the
    /// number of requests allowed per second.
    ///
    /// # Errors
    ///
    /// Returns the [`redis::RedisError`] produced when the URL cannot be opened or the
    /// connection manager cannot be created.
    pub async fn new(
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(backend_config.url.as_str())?;
        let connection = ConnectionManager::new(client).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Redis rate limiter initialized"
        );

        Ok(Self {
            connection,
            config: RwLock::new(RedisConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(DistributedRateLimitStats::default()),
        })
    }

    /// Records a request for `key` and decides whether it exceeds the limit.
    ///
    /// This is a sliding-window log kept in a Redis sorted set scored by
    /// millisecond timestamps. A single atomic pipeline does four things:
    /// 1. drops entries older than the window,
    /// 2. adds an entry for this request,
    /// 3. refreshes the key's TTL to twice the window,
    /// 4. counts the entries inside the window.
    ///
    /// The request is recorded before the decision is made, so rejected requests
    /// also occupy the window.
    ///
    /// Returns the outcome together with the counted number of requests, which
    /// includes this one.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - the Redis round trip fails,
    /// - the round trip exceeds the configured timeout,
    /// - the system clock reads earlier than the UNIX epoch.
    ///
    /// This method does not clear the healthy flag; `mark_unhealthy` does that.
    pub async fn check(&self, key: &str) -> Result<(RateLimitOutcome, i64), redis::RedisError> {
        // Copy what we need out of the config and release the read guard immediately,
        // so that no lock guard lives across the `.await` below.
        let (full_key, max_rps, window_secs, timeout) = {
            let config = self.config.read();
            (
                format!("{}{}", config.key_prefix, key),
                config.max_rps,
                config.window_secs,
                config.timeout,
            )
        };

        // Milliseconds since the UNIX epoch. A clock set before 1970 is surfaced as an
        // error the caller can handle, instead of panicking on the request path.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|_| {
                redis::RedisError::from((
                    redis::ErrorKind::ClientError,
                    "System clock is set before the UNIX epoch",
                ))
            })?
            .as_millis() as f64;

        let window_start = now - (window_secs as f64 * 1000.0);
        let member = unique_window_member(now);

        let mut conn = self.connection.clone();

        // redis-rs `zadd` takes `(key, member, score)`: the timestamp is the score and
        // the unique string is the member.
        let result: Result<(i64,), _> = tokio::time::timeout(timeout, async {
            redis::pipe()
                .atomic()
                .zrembyscore(&full_key, 0.0, window_start)
                .ignore()
                .zadd(&full_key, member.as_str(), now)
                .ignore()
                .expire(&full_key, (window_secs * 2) as i64)
                .ignore()
                .zcount(&full_key, window_start, now)
                .query_async(&mut conn)
                .await
        })
        .await
        .map_err(|_| {
            redis::RedisError::from((redis::ErrorKind::IoError, "Redis operation timed out"))
        })?;

        let (count,) = result?;

        self.healthy.store(true, Ordering::Relaxed);

        // `count` includes this request, so exactly `max_rps` requests fit in a window.
        let outcome = if count > i64::from(max_rps) {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = max_rps,
            outcome = ?outcome,
            "Redis rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Applies reloaded settings: key prefix, limit, timeout and fallback flag.
    ///
    /// The window length is not reconfigurable and keeps its current value.
    pub fn update_config(
        &self,
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
    }

    /// Returns the healthy flag. It is set after each successful check and cleared by
    /// [`Self::mark_unhealthy`].
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Clears the healthy flag and counts one Redis error in `stats`.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_redis_error();
    }

    /// Returns whether the current config asks for local fallback.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
213
/// Placeholder limiter compiled when the `distributed-rate-limit` feature is off.
///
/// It carries no state; its constructor always fails.
#[cfg(not(feature = "distributed-rate-limit"))]
pub struct RedisRateLimiter;
217
218#[cfg(not(feature = "distributed-rate-limit"))]
219impl RedisRateLimiter {
220 pub async fn new(
221 _backend_config: &RedisBackendConfig,
222 _rate_config: &RateLimitConfig,
223 ) -> Result<Self, String> {
224 Err("Distributed rate limiting requires the 'distributed-rate-limit' feature".to_string())
225 }
226}
227
/// Tries to build a [`RedisRateLimiter`], logging the outcome.
///
/// Returns `None` when construction fails. In that case it also emits a warning if the
/// backend config enables local fallback, so the caller is expected to use local
/// limiting instead.
///
/// NOTE(review): the Redis URL is logged verbatim; if deployments embed credentials in
/// it, consider redacting before logging.
#[cfg(feature = "distributed-rate-limit")]
pub async fn create_redis_rate_limiter(
    backend_config: &RedisBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    let failure = match RedisRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => {
            debug!(
                url = %backend_config.url,
                "Redis rate limiter created successfully"
            );
            return Some(limiter);
        }
        Err(failure) => failure,
    };

    error!(
        error = %failure,
        url = %backend_config.url,
        "Failed to create Redis rate limiter"
    );
    if backend_config.fallback_local {
        warn!("Falling back to local rate limiting");
    }
    None
}
255
256#[cfg(not(feature = "distributed-rate-limit"))]
257pub async fn create_redis_rate_limiter(
258 _backend_config: &RedisBackendConfig,
259 _rate_config: &RateLimitConfig,
260) -> Option<RedisRateLimiter> {
261 warn!(
262 "Distributed rate limiting requested but feature is disabled. Using local rate limiting."
263 );
264 None
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a counter with the same relaxed ordering the recorders use.
    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    #[test]
    fn test_stats_recording() {
        let stats = DistributedRateLimitStats::default();

        for _ in 0..2 {
            stats.record_check(RateLimitOutcome::Allowed);
        }
        stats.record_check(RateLimitOutcome::Limited);

        assert_eq!(read(&stats.total_checks), 3);
        assert_eq!(read(&stats.allowed), 2);
        assert_eq!(read(&stats.limited), 1);
    }

    #[test]
    fn test_stats_redis_errors() {
        let stats = DistributedRateLimitStats::default();

        for _ in 0..2 {
            stats.record_redis_error();
        }
        stats.record_local_fallback();

        assert_eq!(read(&stats.redis_errors), 2);
        assert_eq!(read(&stats.local_fallbacks), 1);
    }
}