sentinel_proxy/
memcached_rate_limit.rs

1//! Distributed rate limiting with Memcached backend
2//!
3//! This module provides a Memcached-backed rate limiter for multi-instance deployments.
//! This module provides a Memcached-backed rate limiter for multi-instance deployments.
//! Uses a counter-based fixed window algorithm.
5//!
6//! # Algorithm
7//!
8//! Uses a fixed window counter algorithm with Memcached:
9//! 1. Generate a time-windowed key (current second)
10//! 2. Increment the counter atomically
11//! 3. Allow if count <= max_rps
12//!
13//! Note: This is slightly less accurate than Redis sorted sets but more efficient
14//! for Memcached's simpler data model.
15
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20#[cfg(feature = "distributed-rate-limit-memcached")]
21use async_memcached::AsciiProtocol;
22use parking_lot::RwLock;
23use tracing::{debug, error, trace, warn};
24
25use sentinel_config::MemcachedBackendConfig;
26
27use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
28
/// Statistics for Memcached-based distributed rate limiting
///
/// All counters are monotonically increasing and are updated with
/// `Ordering::Relaxed` atomics (see the `impl` below): reads may observe
/// slightly stale values but are never torn, which is sufficient for
/// metrics reporting.
#[derive(Debug, Default)]
pub struct MemcachedRateLimitStats {
    /// Total requests checked
    pub total_checks: AtomicU64,
    /// Requests allowed
    pub allowed: AtomicU64,
    /// Requests limited
    pub limited: AtomicU64,
    /// Memcached errors (fallback to local)
    pub memcached_errors: AtomicU64,
    /// Local fallback invocations
    pub local_fallbacks: AtomicU64,
}
43
44impl MemcachedRateLimitStats {
45    pub fn record_check(&self, outcome: RateLimitOutcome) {
46        self.total_checks.fetch_add(1, Ordering::Relaxed);
47        match outcome {
48            RateLimitOutcome::Allowed => {
49                self.allowed.fetch_add(1, Ordering::Relaxed);
50            }
51            RateLimitOutcome::Limited => {
52                self.limited.fetch_add(1, Ordering::Relaxed);
53            }
54        }
55    }
56
57    pub fn record_memcached_error(&self) {
58        self.memcached_errors.fetch_add(1, Ordering::Relaxed);
59    }
60
61    pub fn record_local_fallback(&self) {
62        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
63    }
64}
65
/// Memcached-backed distributed rate limiter
///
/// Shares one connection across checks; the `healthy` flag is maintained by
/// `check`/`mark_unhealthy` so callers can decide when to fall back to a
/// local limiter.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub struct MemcachedRateLimiter {
    /// Memcached client
    ///
    /// NOTE(review): `check` holds this parking_lot write guard across
    /// `.await` points, which serializes all checks for the duration of the
    /// network round-trip and makes the future `!Send` — confirm this is
    /// acceptable, or switch to an async mutex.
    client: RwLock<async_memcached::Client>,
    /// Configuration (replaceable at runtime via `update_config`)
    config: RwLock<MemcachedConfig>,
    /// Whether Memcached is currently healthy
    healthy: AtomicBool,
    /// Statistics
    pub stats: Arc<MemcachedRateLimitStats>,
}
78
/// Internal snapshot of the rate-limiting configuration.
///
/// Cloned out of the `RwLock` at the start of each `check` so the lock is
/// not held during the Memcached round-trip.
#[cfg(feature = "distributed-rate-limit-memcached")]
#[derive(Debug, Clone)]
struct MemcachedConfig {
    // Prefix prepended to every counter key.
    key_prefix: String,
    // Maximum allowed requests per second.
    max_rps: u32,
    // Window length in seconds (set to 1 in `new` and never changed by
    // `update_config`).
    window_secs: u64,
    // Per-operation timeout applied around the Memcached call.
    timeout: Duration,
    // Whether callers may fall back to local limiting on Memcached failure.
    fallback_local: bool,
    // TTL applied to counter keys so stale windows expire server-side.
    ttl_secs: u32,
}
89
#[cfg(feature = "distributed-rate-limit-memcached")]
impl MemcachedRateLimiter {
    /// Create a new Memcached rate limiter.
    ///
    /// Connects eagerly so connection problems surface at startup rather
    /// than on the first `check`.
    ///
    /// # Errors
    ///
    /// Returns the client error if the connection cannot be established.
    pub async fn new(
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, async_memcached::Error> {
        // Accept `memcache://host:port`, `memcached://host:port`, or a bare
        // `host:port`; the client wants the bare address.
        let addr = backend_config
            .url
            .trim_start_matches("memcache://")
            .trim_start_matches("memcached://");

        let client = async_memcached::Client::new(addr).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Memcached rate limiter initialized"
        );

        Ok(Self {
            client: RwLock::new(client),
            config: RwLock::new(MemcachedConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                // Fixed 1-second windows; `check` honors this value when
                // bucketing timestamps, so it is the single source of truth
                // for window length.
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
                ttl_secs: backend_config.ttl_secs,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(MemcachedRateLimitStats::default()),
        })
    }

    /// Check if a request should be rate limited.
    ///
    /// Returns the outcome and the current request count in the window.
    ///
    /// # Errors
    ///
    /// Returns an error when the Memcached operation fails or exceeds the
    /// configured timeout. Note that this method does NOT clear the health
    /// flag on failure — callers are expected to invoke `mark_unhealthy`.
    pub async fn check(
        &self,
        key: &str,
    ) -> Result<(RateLimitOutcome, u64), async_memcached::Error> {
        let config = self.config.read().clone();

        // Bucket time into fixed windows. BUG FIX: the key previously used
        // the raw second timestamp, so `window_secs` was dead and the key
        // rolled over every second regardless of configuration; dividing by
        // the window length keeps the key stable for the whole window.
        // `.max(1)` guards against a zero-length window (division by zero).
        let window_len = config.window_secs.max(1);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before the UNIX epoch")
            .as_secs();
        let window_key = format!("{}{}:{}", config.key_prefix, key, now / window_len);

        // Increment the counter atomically on the server.
        //
        // NOTE(review): this parking_lot write guard is held across the
        // `.await`s below, blocking all other checks for the duration of the
        // round-trip and making this future `!Send` — confirm intentional,
        // or switch the field to an async mutex.
        let mut client = self.client.write();

        let result = tokio::time::timeout(config.timeout, async {
            // `increment` fails with NotFound for a fresh window; initialize
            // the counter to 1 with a TTL so stale windows expire on their own.
            match client.increment(&window_key, 1).await {
                Ok(count) => Ok(count),
                Err(async_memcached::Error::Protocol(async_memcached::Status::NotFound)) => {
                    // NOTE(review): two instances racing on a fresh window can
                    // both observe NotFound and each `set` the counter to 1,
                    // undercounting by one request; `add` + retry-increment
                    // would close that race.
                    client
                        .set(&window_key, &b"1"[..], Some(config.ttl_secs as i64), None)
                        .await
                        .map(|_| 1u64)
                }
                Err(e) => Err(e),
            }
        })
        .await
        .map_err(|_| {
            // Surface a timeout as an I/O error so callers handle one error
            // type for both network failure and slowness.
            async_memcached::Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "Memcached operation timed out",
            ))
        })??;

        // A successful round-trip proves the backend is reachable.
        self.healthy.store(true, Ordering::Relaxed);

        // Scale the budget by the window length so `max_rps` keeps its
        // per-second meaning for windows longer than one second (identical
        // to the previous comparison while `window_secs == 1`).
        let outcome = if result > config.max_rps as u64 * window_len {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = result,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Memcached rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, result))
    }

    /// Update configuration (hot reload).
    ///
    /// `window_secs` is fixed at construction and intentionally left
    /// unchanged here.
    pub fn update_config(
        &self,
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
        config.ttl_secs = backend_config.ttl_secs;
    }

    /// Check if Memcached is currently healthy.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Mark Memcached as unhealthy (will trigger fallback) and record the
    /// error in stats.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_memcached_error();
    }

    /// Check if fallback to local rate limiting is enabled.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
218
/// Stub for when distributed-rate-limit-memcached feature is disabled
///
/// Unit struct with no state; its `new` always errors so callers fall back
/// to local rate limiting.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub struct MemcachedRateLimiter;
222
223#[cfg(not(feature = "distributed-rate-limit-memcached"))]
224impl MemcachedRateLimiter {
225    pub async fn new(
226        _backend_config: &MemcachedBackendConfig,
227        _rate_config: &RateLimitConfig,
228    ) -> Result<Self, String> {
229        Err(
230            "Memcached rate limiting requires the 'distributed-rate-limit-memcached' feature"
231                .to_string(),
232        )
233    }
234}
235
/// Create a Memcached rate limiter from configuration
///
/// Logs the outcome either way and converts the fallible constructor into
/// an `Option` for callers that treat "no distributed limiter" uniformly.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub async fn create_memcached_rate_limiter(
    backend_config: &MemcachedBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    let result = MemcachedRateLimiter::new(backend_config, rate_config).await;

    // Log before discarding the error payload via `.ok()`.
    match &result {
        Ok(_) => {
            debug!(
                url = %backend_config.url,
                "Memcached rate limiter created successfully"
            );
        }
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Memcached rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
        }
    }

    result.ok()
}
263
264#[cfg(not(feature = "distributed-rate-limit-memcached"))]
265pub async fn create_memcached_rate_limiter(
266    _backend_config: &MemcachedBackendConfig,
267    _rate_config: &RateLimitConfig,
268) -> Option<MemcachedRateLimiter> {
269    warn!(
270        "Memcached rate limiting requested but feature is disabled. Using local rate limiting."
271    );
272    None
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-outcome counters and the total must all advance independently.
    #[test]
    fn test_stats_recording() {
        let stats = MemcachedRateLimitStats::default();

        // Two allowed, one limited.
        for _ in 0..2 {
            stats.record_check(RateLimitOutcome::Allowed);
        }
        stats.record_check(RateLimitOutcome::Limited);

        assert_eq!(3, stats.total_checks.load(Ordering::Relaxed));
        assert_eq!(2, stats.allowed.load(Ordering::Relaxed));
        assert_eq!(1, stats.limited.load(Ordering::Relaxed));
    }

    /// Error and fallback counters are independent of each other.
    #[test]
    fn test_stats_memcached_errors() {
        let stats = MemcachedRateLimitStats::default();

        for _ in 0..2 {
            stats.record_memcached_error();
        }
        stats.record_local_fallback();

        assert_eq!(2, stats.memcached_errors.load(Ordering::Relaxed));
        assert_eq!(1, stats.local_fallbacks.load(Ordering::Relaxed));
    }
}