// sentinel_proxy/distributed_rate_limit.rs

//! Distributed rate limiting with Redis backend
//!
//! This module provides a Redis-backed rate limiter for multi-instance deployments.
//! Uses a sliding window algorithm implemented with Redis sorted sets.
//!
//! # Algorithm
//!
//! Uses a sliding window log algorithm:
//! 1. Store each request timestamp in a Redis sorted set
//! 2. Remove timestamps older than the window (1 second)
//! 3. Count remaining timestamps
//! 4. Allow if count <= max_rps
//!
//! This provides accurate rate limiting across multiple instances with minimal
//! Redis operations (single MULTI/EXEC transaction per request).
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;

use tracing::{debug, error, trace, warn};

#[cfg(feature = "distributed-rate-limit")]
use redis::aio::ConnectionManager;

use sentinel_config::RedisBackendConfig;

use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
27
/// Statistics for distributed rate limiting
///
/// All counters are monotonically increasing and are updated with
/// `Ordering::Relaxed`; they are observability-only and add no
/// synchronization to the request path.
#[derive(Debug, Default)]
pub struct DistributedRateLimitStats {
    /// Total requests checked
    pub total_checks: AtomicU64,
    /// Requests allowed
    pub allowed: AtomicU64,
    /// Requests limited
    pub limited: AtomicU64,
    /// Redis errors (fallback to local)
    pub redis_errors: AtomicU64,
    /// Local fallback invocations
    pub local_fallbacks: AtomicU64,
}
42
43impl DistributedRateLimitStats {
44    pub fn record_check(&self, outcome: RateLimitOutcome) {
45        self.total_checks.fetch_add(1, Ordering::Relaxed);
46        match outcome {
47            RateLimitOutcome::Allowed => {
48                self.allowed.fetch_add(1, Ordering::Relaxed);
49            }
50            RateLimitOutcome::Limited => {
51                self.limited.fetch_add(1, Ordering::Relaxed);
52            }
53        }
54    }
55
56    pub fn record_redis_error(&self) {
57        self.redis_errors.fetch_add(1, Ordering::Relaxed);
58    }
59
60    pub fn record_local_fallback(&self) {
61        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
62    }
63}
64
/// Redis-backed distributed rate limiter
///
/// Holds a multiplexed Redis connection, hot-swappable configuration, a
/// health flag (cleared on Redis failures so callers can fall back), and
/// shared statistics. Only compiled with the `distributed-rate-limit`
/// feature.
#[cfg(feature = "distributed-rate-limit")]
pub struct RedisRateLimiter {
    /// Redis connection manager (handles reconnection)
    connection: ConnectionManager,
    /// Configuration
    config: RwLock<RedisConfig>,
    /// Whether Redis is currently healthy
    healthy: AtomicBool,
    /// Statistics
    pub stats: Arc<DistributedRateLimitStats>,
}
77
/// Runtime-updatable settings for the Redis rate limiter.
#[cfg(feature = "distributed-rate-limit")]
#[derive(Debug, Clone)]
struct RedisConfig {
    /// Prefix prepended to every rate-limit key (namespacing in Redis).
    key_prefix: String,
    /// Maximum requests allowed per sliding window.
    max_rps: u32,
    /// Window length in seconds (set to 1 by `RedisRateLimiter::new`).
    window_secs: u64,
    /// Per-operation timeout for Redis round trips.
    timeout: Duration,
    /// Whether to fall back to local limiting when Redis is unavailable.
    fallback_local: bool,
}
87
#[cfg(feature = "distributed-rate-limit")]
impl RedisRateLimiter {
    /// Create a new Redis rate limiter
    ///
    /// Opens a client for `backend_config.url` and wraps it in a
    /// `ConnectionManager`, which transparently reconnects after failures.
    ///
    /// # Errors
    ///
    /// Returns the underlying `redis::RedisError` if the URL is invalid or
    /// the initial connection cannot be established.
    pub async fn new(
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(backend_config.url.as_str())?;
        let connection = ConnectionManager::new(client).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Redis rate limiter initialized"
        );

        Ok(Self {
            connection,
            config: RwLock::new(RedisConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                // Fixed 1s window so max_rps maps directly to requests/second.
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(DistributedRateLimitStats::default()),
        })
    }

    /// Check if a request should be rate limited
    ///
    /// Returns the outcome and the current request count in the window.
    ///
    /// Sliding-window log: prune entries older than the window, insert this
    /// request, then count what remains — all in one atomic MULTI/EXEC
    /// pipeline.
    ///
    /// # Errors
    ///
    /// Any Redis failure (including a timeout after `config.timeout`) marks
    /// the limiter unhealthy, records the error in stats, and is returned to
    /// the caller, who may fall back to local limiting.
    pub async fn check(&self, key: &str) -> Result<(RateLimitOutcome, i64), redis::RedisError> {
        let config = self
            .config
            .read()
            .expect("rate limit config lock poisoned")
            .clone();
        let full_key = format!("{}{}", config.key_prefix, key);

        // Millisecond timestamps double as sorted-set scores.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before the UNIX epoch")
            .as_millis() as f64;

        let window_start = now - (config.window_secs as f64 * 1000.0);

        // BUG FIX: the set member must be unique per request. Using the bare
        // timestamp collapses concurrent requests in the same millisecond
        // (possibly from different instances) into a single entry, which
        // undercounts. Combine timestamp + pid + a per-process sequence.
        static SEQ: AtomicU64 = AtomicU64::new(0);
        let member = format!(
            "{now}-{}-{}",
            std::process::id(),
            SEQ.fetch_add(1, Ordering::Relaxed)
        );

        // ConnectionManager clones share the underlying multiplexed
        // connection, so this is cheap.
        let mut conn = self.connection.clone();

        // Atomic operation: remove old entries, add new entry, count entries
        let result: Result<Result<(i64,), redis::RedisError>, _> =
            tokio::time::timeout(config.timeout, async {
                redis::pipe()
                    .atomic()
                    // Remove timestamps older than window
                    .zrembyscore(&full_key, 0.0, window_start)
                    .ignore()
                    // Add current request with score = timestamp
                    .zadd(&full_key, member, now)
                    .ignore()
                    // Set expiration to prevent memory leaks
                    .expire(&full_key, (config.window_secs * 2) as i64)
                    .ignore()
                    // Count entries in window
                    .zcount(&full_key, window_start, now)
                    .query_async(&mut conn)
                    .await
            })
            .await;

        let count = match result {
            Ok(Ok((count,))) => count,
            // BUG FIX: errors previously propagated without flipping the
            // health flag, so fallback never triggered from the hot path.
            Ok(Err(e)) => {
                self.healthy.store(false, Ordering::Relaxed);
                self.stats.record_redis_error();
                return Err(e);
            }
            // A timeout is treated exactly like a Redis error.
            Err(_) => {
                self.healthy.store(false, Ordering::Relaxed);
                self.stats.record_redis_error();
                return Err(redis::RedisError::from((
                    redis::ErrorKind::IoError,
                    "Redis operation timed out",
                )));
            }
        };

        self.healthy.store(true, Ordering::Relaxed);

        // The current request is already included in `count`, so
        // strictly-greater admits up to max_rps requests per window.
        let outcome = if count > config.max_rps as i64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Redis rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Update configuration
    ///
    /// Applies new backend and rate settings under the config lock.
    /// `window_secs` is intentionally left untouched (fixed at 1s by `new`).
    pub fn update_config(
        &self,
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write().expect("rate limit config lock poisoned");
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
    }

    /// Check if Redis is currently healthy
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Mark Redis as unhealthy (will trigger fallback)
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_redis_error();
    }

    /// Check if fallback to local is enabled
    pub fn fallback_enabled(&self) -> bool {
        self.config
            .read()
            .expect("rate limit config lock poisoned")
            .fallback_local
    }
}
210
/// Stub for when distributed-rate-limit feature is disabled
///
/// Zero-sized placeholder so call sites can name the type regardless of
/// feature flags; its `new` constructor always returns an error.
#[cfg(not(feature = "distributed-rate-limit"))]
pub struct RedisRateLimiter;
214
215#[cfg(not(feature = "distributed-rate-limit"))]
216impl RedisRateLimiter {
217    pub async fn new(
218        _backend_config: &RedisBackendConfig,
219        _rate_config: &RateLimitConfig,
220    ) -> Result<Self, String> {
221        Err("Distributed rate limiting requires the 'distributed-rate-limit' feature".to_string())
222    }
223}
224
/// Create a Redis rate limiter from configuration
///
/// Returns `None` when the limiter cannot be created (bad URL, Redis
/// unreachable); the caller decides whether to continue with local-only
/// rate limiting.
#[cfg(feature = "distributed-rate-limit")]
pub async fn create_redis_rate_limiter(
    backend_config: &RedisBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    match RedisRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => {
            debug!(
                url = %backend_config.url,
                "Redis rate limiter created successfully"
            );
            Some(limiter)
        }
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Redis rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            } else {
                // Make the degraded state loud: with fallback disabled the
                // caller gets neither a distributed limiter nor a local one.
                warn!("Local fallback disabled; distributed rate limiting is unavailable");
            }
            None
        }
    }
}
252
253#[cfg(not(feature = "distributed-rate-limit"))]
254pub async fn create_redis_rate_limiter(
255    _backend_config: &RedisBackendConfig,
256    _rate_config: &RateLimitConfig,
257) -> Option<RedisRateLimiter> {
258    warn!(
259        "Distributed rate limiting requested but feature is disabled. Using local rate limiting."
260    );
261    None
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcome counters split correctly and the total tracks every check.
    #[test]
    fn test_stats_recording() {
        let stats = DistributedRateLimitStats::default();

        for outcome in [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ] {
            stats.record_check(outcome);
        }

        assert_eq!(stats.total_checks.load(Ordering::Relaxed), 3);
        assert_eq!(stats.allowed.load(Ordering::Relaxed), 2);
        assert_eq!(stats.limited.load(Ordering::Relaxed), 1);
    }

    /// Error and fallback counters are independent of each other.
    #[test]
    fn test_stats_redis_errors() {
        let stats = DistributedRateLimitStats::default();

        stats.record_redis_error();
        stats.record_redis_error();
        stats.record_local_fallback();

        assert_eq!(stats.redis_errors.load(Ordering::Relaxed), 2);
        assert_eq!(stats.local_fallbacks.load(Ordering::Relaxed), 1);
    }
}