sentinel_proxy/
distributed_rate_limit.rs

//! Distributed rate limiting with Redis backend
//!
//! This module provides a Redis-backed rate limiter for multi-instance deployments.
//! Uses a sliding window algorithm implemented with Redis sorted sets.
//!
//! # Algorithm
//!
//! Uses a sliding window log algorithm:
//! 1. Store each request timestamp in a Redis sorted set
//! 2. Remove timestamps older than the window (1 second)
//! 3. Count the timestamps remaining in the window
//! 4. Allow the request if the count is <= max_rps
//!
//! This provides accurate rate limiting across multiple instances with minimal
//! Redis overhead (a single MULTI/EXEC transaction per request).
17use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20
21use parking_lot::RwLock;
22use tracing::{debug, error, trace, warn};
23
24#[cfg(feature = "distributed-rate-limit")]
25use redis::aio::ConnectionManager;
26
27use sentinel_config::RedisBackendConfig;
28
29use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
30
/// Statistics for distributed rate limiting
///
/// All fields are atomic counters so the struct can be shared across tasks
/// (it is held in an `Arc` by `RedisRateLimiter`) and updated with
/// `Ordering::Relaxed` — these are monotonically increasing telemetry
/// counters, not synchronization points.
#[derive(Debug, Default)]
pub struct DistributedRateLimitStats {
    /// Total requests checked (allowed + limited)
    pub total_checks: AtomicU64,
    /// Requests allowed
    pub allowed: AtomicU64,
    /// Requests limited
    pub limited: AtomicU64,
    /// Redis errors (fallback to local)
    pub redis_errors: AtomicU64,
    /// Local fallback invocations
    pub local_fallbacks: AtomicU64,
}
45
46impl DistributedRateLimitStats {
47    pub fn record_check(&self, outcome: RateLimitOutcome) {
48        self.total_checks.fetch_add(1, Ordering::Relaxed);
49        match outcome {
50            RateLimitOutcome::Allowed => {
51                self.allowed.fetch_add(1, Ordering::Relaxed);
52            }
53            RateLimitOutcome::Limited => {
54                self.limited.fetch_add(1, Ordering::Relaxed);
55            }
56        }
57    }
58
59    pub fn record_redis_error(&self) {
60        self.redis_errors.fetch_add(1, Ordering::Relaxed);
61    }
62
63    pub fn record_local_fallback(&self) {
64        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
65    }
66}
67
/// Redis-backed distributed rate limiter
///
/// All proxy instances pointing at the same Redis share one sliding-window
/// counter per key, giving a cluster-wide limit rather than a per-instance
/// one. Only compiled when the `distributed-rate-limit` feature is enabled.
#[cfg(feature = "distributed-rate-limit")]
pub struct RedisRateLimiter {
    /// Redis connection manager (handles reconnection)
    connection: ConnectionManager,
    /// Configuration; behind an RwLock so `update_config` can hot-swap it
    /// without restarting the limiter
    config: RwLock<RedisConfig>,
    /// Whether Redis is currently healthy — set false by `mark_unhealthy`,
    /// set true again after any successful `check`
    healthy: AtomicBool,
    /// Statistics, shared via `Arc` so callers can snapshot them
    pub stats: Arc<DistributedRateLimitStats>,
}
80
/// Internal, hot-swappable configuration for `RedisRateLimiter`.
#[cfg(feature = "distributed-rate-limit")]
#[derive(Debug, Clone)]
struct RedisConfig {
    // Prefix prepended to every rate-limit key stored in Redis.
    key_prefix: String,
    // Maximum requests allowed inside one window.
    max_rps: u32,
    // Sliding window length in seconds (set to 1 by `new`; `update_config`
    // does not change it).
    window_secs: u64,
    // Per-operation timeout applied to the whole Redis pipeline.
    timeout: Duration,
    // Whether to fall back to local limiting when Redis is unavailable.
    fallback_local: bool,
}
90
#[cfg(feature = "distributed-rate-limit")]
impl RedisRateLimiter {
    /// Create a new Redis rate limiter
    ///
    /// Opens a client for `backend_config.url` and wraps it in a
    /// `ConnectionManager`, which transparently reconnects on failure.
    ///
    /// # Errors
    ///
    /// Returns any error from building the client or establishing the
    /// initial connection.
    pub async fn new(
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(backend_config.url.as_str())?;
        let connection = ConnectionManager::new(client).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Redis rate limiter initialized"
        );

        Ok(Self {
            connection,
            config: RwLock::new(RedisConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                // Window is fixed at 1s because max_rps is a per-second limit.
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(DistributedRateLimitStats::default()),
        })
    }

    /// Check if a request should be rate limited
    ///
    /// Returns the outcome and the current request count in the window.
    ///
    /// Sliding-window log over a Redis sorted set: one MULTI/EXEC pipeline
    /// (1) drops entries older than the window, (2) records this request,
    /// (3) refreshes the key TTL, (4) counts entries in the window. The
    /// count includes the current request, so the limit fires once the
    /// window holds strictly more than `max_rps` entries. Note that limited
    /// requests still add an entry, so rejected traffic keeps occupying
    /// window slots.
    ///
    /// # Errors
    ///
    /// Returns a Redis error if the pipeline fails or does not complete
    /// within the configured timeout; the caller is expected to call
    /// `mark_unhealthy` and fall back to local limiting in that case.
    pub async fn check(&self, key: &str) -> Result<(RateLimitOutcome, i64), redis::RedisError> {
        // Per-process sequence number so two requests that land in the same
        // millisecond (on this or another instance) produce distinct sorted-set
        // members. Previously the bare timestamp was used as the member, so
        // same-millisecond requests overwrote each other via ZADD and were
        // undercounted — exactly the multi-instance case this limiter exists for.
        static SEQ: AtomicU64 = AtomicU64::new(0);

        let config = self.config.read().clone();
        let full_key = format!("{}{}", config.key_prefix, key);

        // Millisecond timestamps fit exactly in f64 (< 2^53), so using f64
        // scores is lossless here.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_millis() as f64;

        let window_start = now - (config.window_secs as f64 * 1000.0);

        // Unique member; the score used by ZREMRANGEBYSCORE/ZCOUNT is still
        // the raw timestamp.
        let member = format!(
            "{}-{}-{}",
            now,
            std::process::id(),
            SEQ.fetch_add(1, Ordering::Relaxed)
        );

        // ConnectionManager is a cheap handle to a shared multiplexed
        // connection; cloning is the intended usage pattern.
        let mut conn = self.connection.clone();

        let result: Result<(i64,), _> = tokio::time::timeout(config.timeout, async {
            redis::pipe()
                .atomic()
                // Remove timestamps older than the window
                .zrembyscore(&full_key, 0.0, window_start)
                .ignore()
                // Add this request: member is unique, score is the timestamp
                .zadd(&full_key, member, now)
                .ignore()
                // Expire idle keys (2x window) to prevent memory leaks
                .expire(&full_key, (config.window_secs * 2) as i64)
                .ignore()
                // Count entries in window (bounds inclusive)
                .zcount(&full_key, window_start, now)
                .query_async(&mut conn)
                .await
        })
        .await
        .map_err(|_| {
            redis::RedisError::from((redis::ErrorKind::IoError, "Redis operation timed out"))
        })?;

        let (count,) = result?;

        // A successful round trip means Redis is reachable again.
        self.healthy.store(true, Ordering::Relaxed);

        let outcome = if count > config.max_rps as i64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Redis rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Update configuration
    ///
    /// Hot-swaps prefix, limit, timeout, and fallback flag. `window_secs`
    /// is intentionally left untouched — the window stays 1 second to match
    /// the per-second semantics of `max_rps`.
    pub fn update_config(
        &self,
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
    }

    /// Check if Redis is currently healthy
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Mark Redis as unhealthy (will trigger fallback) and record the error.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_redis_error();
    }

    /// Check if fallback to local rate limiting is enabled
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
213
/// Stub for when distributed-rate-limit feature is disabled
///
/// Zero-sized placeholder so the type name still exists at call sites; the
/// stub constructor below always returns an error.
#[cfg(not(feature = "distributed-rate-limit"))]
pub struct RedisRateLimiter;
217
218#[cfg(not(feature = "distributed-rate-limit"))]
219impl RedisRateLimiter {
220    pub async fn new(
221        _backend_config: &RedisBackendConfig,
222        _rate_config: &RateLimitConfig,
223    ) -> Result<Self, String> {
224        Err("Distributed rate limiting requires the 'distributed-rate-limit' feature".to_string())
225    }
226}
227
/// Create a Redis rate limiter from configuration
///
/// Returns `None` (after logging the failure) when the limiter cannot be
/// constructed; the caller decides whether to use local limiting instead.
#[cfg(feature = "distributed-rate-limit")]
pub async fn create_redis_rate_limiter(
    backend_config: &RedisBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    let limiter = match RedisRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => limiter,
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Redis rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
            return None;
        }
    };

    debug!(
        url = %backend_config.url,
        "Redis rate limiter created successfully"
    );
    Some(limiter)
}
255
256#[cfg(not(feature = "distributed-rate-limit"))]
257pub async fn create_redis_rate_limiter(
258    _backend_config: &RedisBackendConfig,
259    _rate_config: &RateLimitConfig,
260) -> Option<RedisRateLimiter> {
261    warn!(
262        "Distributed rate limiting requested but feature is disabled. Using local rate limiting."
263    );
264    None
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Three checks (two allowed, one limited) land in the right buckets.
    #[test]
    fn test_stats_recording() {
        let stats = DistributedRateLimitStats::default();

        for outcome in [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ] {
            stats.record_check(outcome);
        }

        assert_eq!(stats.total_checks.load(Ordering::Relaxed), 3);
        assert_eq!(stats.allowed.load(Ordering::Relaxed), 2);
        assert_eq!(stats.limited.load(Ordering::Relaxed), 1);
    }

    /// Error and fallback counters are independent of each other.
    #[test]
    fn test_stats_redis_errors() {
        let stats = DistributedRateLimitStats::default();

        (0..2).for_each(|_| stats.record_redis_error());
        stats.record_local_fallback();

        assert_eq!(stats.redis_errors.load(Ordering::Relaxed), 2);
        assert_eq!(stats.local_fallbacks.load(Ordering::Relaxed), 1);
    }
}