1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20#[cfg(feature = "distributed-rate-limit-memcached")]
21use async_memcached::AsciiProtocol;
22use parking_lot::RwLock;
23use tracing::{debug, error, trace, warn};
24
25use sentinel_config::MemcachedBackendConfig;
26
27use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
28
/// Counters describing the activity of the memcached rate limiter.
///
/// Every field is an atomic, so the counters can be bumped through a shared
/// reference (e.g. from behind an `Arc`).
#[derive(Default, Debug)]
pub struct MemcachedRateLimitStats {
    /// Number of rate-limit decisions recorded.
    pub total_checks: AtomicU64,
    /// Decisions that let the request through.
    pub allowed: AtomicU64,
    /// Decisions that rejected the request.
    pub limited: AtomicU64,
    /// Memcached failures recorded via `record_memcached_error`.
    pub memcached_errors: AtomicU64,
    /// Local fallbacks recorded via `record_local_fallback`.
    pub local_fallbacks: AtomicU64,
}
43
44impl MemcachedRateLimitStats {
45 pub fn record_check(&self, outcome: RateLimitOutcome) {
46 self.total_checks.fetch_add(1, Ordering::Relaxed);
47 match outcome {
48 RateLimitOutcome::Allowed => {
49 self.allowed.fetch_add(1, Ordering::Relaxed);
50 }
51 RateLimitOutcome::Limited => {
52 self.limited.fetch_add(1, Ordering::Relaxed);
53 }
54 }
55 }
56
57 pub fn record_memcached_error(&self) {
58 self.memcached_errors.fetch_add(1, Ordering::Relaxed);
59 }
60
61 pub fn record_local_fallback(&self) {
62 self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
63 }
64}
65
/// Fixed-window rate limiter whose counters are stored in memcached.
///
/// Each check increments a counter keyed by `key_prefix`, the caller's key and
/// the current window, then compares the count against the configured limit.
/// Any process using the same memcached server and prefix shares these counters.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub struct MemcachedRateLimiter {
    /// Memcached connection.
    ///
    /// Each request needs `&mut` access to the client, and that access is held
    /// across `.await` points. An async mutex is therefore required: a blocking
    /// lock held across an await stalls the executor thread. On a current-thread
    /// runtime, a second `check` blocking on such a lock would deadlock the task
    /// that holds it.
    client: tokio::sync::Mutex<async_memcached::Client>,
    /// Reloadable settings. Only held briefly; `check` works on a cloned snapshot.
    config: RwLock<MemcachedConfig>,
    /// Set after every successful check, cleared by `mark_unhealthy`.
    healthy: AtomicBool,
    /// Activity counters for this limiter.
    pub stats: Arc<MemcachedRateLimitStats>,
}

/// Settings used by [`MemcachedRateLimiter`].
#[cfg(feature = "distributed-rate-limit-memcached")]
#[derive(Debug, Clone)]
struct MemcachedConfig {
    /// Prepended to every counter key.
    key_prefix: String,
    /// Requests allowed per second.
    max_rps: u32,
    /// Length of one counting window in seconds (`new` sets this to 1).
    window_secs: u64,
    /// Deadline for one check, including the wait for the connection lock.
    timeout: Duration,
    /// Exposed through `fallback_enabled`; presumably tells callers to use local
    /// limiting when memcached fails (TODO confirm at call sites).
    fallback_local: bool,
    /// Expiry in seconds given to a newly created window counter.
    ttl_secs: u32,
}

#[cfg(feature = "distributed-rate-limit-memcached")]
impl MemcachedRateLimiter {
    /// Connects to memcached and builds a limiter from the backend and rate configs.
    ///
    /// A leading `memcache://` or `memcached://` scheme is stripped from the URL
    /// before it is passed to the client.
    ///
    /// # Errors
    /// Returns the client error if the connection cannot be established.
    pub async fn new(
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, async_memcached::Error> {
        let addr = backend_config
            .url
            .trim_start_matches("memcache://")
            .trim_start_matches("memcached://");

        let client = async_memcached::Client::new(addr).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Memcached rate limiter initialized"
        );

        Ok(Self {
            client: tokio::sync::Mutex::new(client),
            config: RwLock::new(MemcachedConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
                ttl_secs: backend_config.ttl_secs,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(MemcachedRateLimitStats::default()),
        })
    }

    /// Counts one request for `key` in the current window and decides whether it
    /// is allowed.
    ///
    /// Returns the outcome together with the window's counter value after this
    /// request. The request is `Limited` when the count exceeds
    /// `max_rps * window_secs`.
    ///
    /// # Errors
    /// Returns the memcached error if increment/set fails, or an `Io` error of
    /// kind `TimedOut` if the whole operation exceeds the configured timeout.
    /// Errors do not update `healthy` or `stats`; that is left to the caller.
    pub async fn check(
        &self,
        key: &str,
    ) -> Result<(RateLimitOutcome, u64), async_memcached::Error> {
        // Snapshot so the parking_lot read guard is released before any `.await`.
        let config = self.config.read().clone();

        // Guard against a zero-length window (division by zero below).
        let window_secs = config.window_secs.max(1);
        // A clock set before the Unix epoch is treated as t=0 instead of panicking.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0);
        let window_key = format!("{}{}:{}", config.key_prefix, key, now / window_secs);

        // NOTE(review): if the timeout fires mid-request, the connection may be left
        // with an unread response that a later request would consume. Consider
        // reconnecting after a timeout; confirm against async_memcached behavior.
        let count = tokio::time::timeout(config.timeout, async {
            let mut client = self.client.lock().await;
            match client.increment(&window_key, 1).await {
                Ok(count) => Ok(count),
                // First hit in this window: create the counter with an expiry.
                // NOTE(review): two instances can both see NotFound and both `set`
                // to 1, losing one count. An `add` followed by `increment` on
                // NOT_STORED would close that race.
                Err(async_memcached::Error::Protocol(async_memcached::Status::NotFound)) => {
                    client
                        .set(&window_key, &b"1"[..], Some(i64::from(config.ttl_secs)), None)
                        .await
                        .map(|_| 1u64)
                }
                Err(e) => Err(e),
            }
        })
        .await
        .map_err(|_| {
            async_memcached::Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "Memcached operation timed out",
            ))
        })??;

        self.healthy.store(true, Ordering::Relaxed);

        let limit = u64::from(config.max_rps).saturating_mul(window_secs);
        let outcome = if count > limit {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Memcached rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Applies new prefix, limit, timeout, fallback and TTL settings.
    ///
    /// The existing connection is kept. A changed URL is not reconnected here, and
    /// `window_secs` is not changed.
    pub fn update_config(
        &self,
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
        config.ttl_secs = backend_config.ttl_secs;
    }

    /// Returns the last recorded health state.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Marks the limiter unhealthy and counts one memcached error.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_memcached_error();
    }

    /// Returns the configured `fallback_local` flag.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
220
/// Stateless stand-in compiled when the `distributed-rate-limit-memcached`
/// feature is off.
///
/// Its `new` always returns an error explaining that the feature is required.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub struct MemcachedRateLimiter;
224
225#[cfg(not(feature = "distributed-rate-limit-memcached"))]
226impl MemcachedRateLimiter {
227 pub async fn new(
228 _backend_config: &MemcachedBackendConfig,
229 _rate_config: &RateLimitConfig,
230 ) -> Result<Self, String> {
231 Err(
232 "Memcached rate limiting requires the 'distributed-rate-limit-memcached' feature"
233 .to_string(),
234 )
235 }
236}
237
/// Builds a memcached-backed limiter, logging the result.
///
/// Returns `None` if the connection fails. The failure is logged as an error,
/// followed by a warning when `fallback_local` is set.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub async fn create_memcached_rate_limiter(
    backend_config: &MemcachedBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    let failure = match MemcachedRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => {
            debug!(
                url = %backend_config.url,
                "Memcached rate limiter created successfully"
            );
            return Some(limiter);
        }
        Err(failure) => failure,
    };

    error!(
        error = %failure,
        url = %backend_config.url,
        "Failed to create Memcached rate limiter"
    );
    if backend_config.fallback_local {
        warn!("Falling back to local rate limiting");
    }
    None
}
265
266#[cfg(not(feature = "distributed-rate-limit-memcached"))]
267pub async fn create_memcached_rate_limiter(
268 _backend_config: &MemcachedBackendConfig,
269 _rate_config: &RateLimitConfig,
270) -> Option<MemcachedRateLimiter> {
271 warn!("Memcached rate limiting requested but feature is disabled. Using local rate limiting.");
272 None
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a counter with the same relaxed ordering the stats code uses.
    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    #[test]
    fn test_stats_recording() {
        let stats = MemcachedRateLimitStats::default();

        let outcomes = [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ];
        for outcome in outcomes {
            stats.record_check(outcome);
        }

        assert_eq!(read(&stats.total_checks), 3);
        assert_eq!(read(&stats.allowed), 2);
        assert_eq!(read(&stats.limited), 1);
    }

    #[test]
    fn test_stats_memcached_errors() {
        let stats = MemcachedRateLimitStats::default();

        for _ in 0..2 {
            stats.record_memcached_error();
        }
        stats.record_local_fallback();

        assert_eq!(read(&stats.memcached_errors), 2);
        assert_eq!(read(&stats.local_fallbacks), 1);
    }
}