1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20#[cfg(feature = "distributed-rate-limit-memcached")]
21use async_memcached::AsciiProtocol;
22use parking_lot::RwLock;
23use tracing::{debug, error, trace, warn};
24
25use sentinel_config::MemcachedBackendConfig;
26
27use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
28
/// Usage counters for the Memcached-backed rate limiter.
///
/// Every counter starts at zero (`Default`) and is a plain `AtomicU64`, so it can be bumped
/// through a shared reference and read at any time for metrics.
#[derive(Default, Debug)]
pub struct MemcachedRateLimitStats {
    /// Rate-limit checks recorded so far (the sum of `allowed` and `limited`).
    pub total_checks: AtomicU64,
    /// Recorded checks whose outcome was `Allowed`.
    pub allowed: AtomicU64,
    /// Recorded checks whose outcome was `Limited`.
    pub limited: AtomicU64,
    /// Memcached failures reported by callers.
    pub memcached_errors: AtomicU64,
    /// Times a caller reported falling back to local rate limiting.
    pub local_fallbacks: AtomicU64,
}
43
44impl MemcachedRateLimitStats {
45 pub fn record_check(&self, outcome: RateLimitOutcome) {
46 self.total_checks.fetch_add(1, Ordering::Relaxed);
47 match outcome {
48 RateLimitOutcome::Allowed => {
49 self.allowed.fetch_add(1, Ordering::Relaxed);
50 }
51 RateLimitOutcome::Limited => {
52 self.limited.fetch_add(1, Ordering::Relaxed);
53 }
54 }
55 }
56
57 pub fn record_memcached_error(&self) {
58 self.memcached_errors.fetch_add(1, Ordering::Relaxed);
59 }
60
61 pub fn record_local_fallback(&self) {
62 self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
63 }
64}
65
/// Fixed-window rate limiter whose per-key counters live in a Memcached server.
///
/// Counter keys are `"{key_prefix}{key}:{window}"`, so every limiter that talks to the same
/// server with the same prefix increments the same counters.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub struct MemcachedRateLimiter {
    /// The single Memcached connection; `None` means it must be re-opened on the next check.
    ///
    /// This is an async (`tokio`) mutex because the guard is held across `.await`. Holding a
    /// blocking `parking_lot` guard there would block the executor thread while waiting on the
    /// network, make the future `!Send`, and can deadlock two tasks sharing a worker thread.
    client: tokio::sync::Mutex<Option<async_memcached::Client>>,
    /// Address handed to `async_memcached::Client::new`: the configured URL with its
    /// `memcache://` / `memcached://` scheme stripped. Kept so the connection can be re-opened.
    addr: String,
    /// Settings that `update_config` can replace at runtime.
    config: RwLock<MemcachedConfig>,
    /// Cleared by `mark_unhealthy`; set again after every successful `check`.
    healthy: AtomicBool,
    /// Check and failure counters.
    pub stats: Arc<MemcachedRateLimitStats>,
}

/// Runtime-tunable settings of a [`MemcachedRateLimiter`].
#[cfg(feature = "distributed-rate-limit-memcached")]
#[derive(Debug, Clone)]
struct MemcachedConfig {
    /// Prepended to every counter key.
    key_prefix: String,
    /// Requests allowed per second and key.
    max_rps: u32,
    /// Window length in seconds (currently always 1); 0 is treated as 1.
    window_secs: u64,
    /// Upper bound for one `check`, including waiting for and (re)opening the connection.
    timeout: Duration,
    /// Reported by `fallback_enabled`.
    fallback_local: bool,
    /// Expiry given to newly created window counters.
    ttl_secs: u32,
}

#[cfg(feature = "distributed-rate-limit-memcached")]
impl MemcachedRateLimiter {
    /// Connects to the server named by `backend_config.url` and builds a limiter allowing
    /// `rate_config.max_rps` requests per second per key.
    ///
    /// # Errors
    ///
    /// Returns the client error if the initial connection cannot be established.
    pub async fn new(
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, async_memcached::Error> {
        let addr = backend_config
            .url
            .trim_start_matches("memcache://")
            .trim_start_matches("memcached://")
            .to_owned();

        let client = async_memcached::Client::new(addr.as_str()).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Memcached rate limiter initialized"
        );

        Ok(Self {
            client: tokio::sync::Mutex::new(Some(client)),
            addr,
            config: RwLock::new(MemcachedConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
                ttl_secs: backend_config.ttl_secs,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(MemcachedRateLimitStats::default()),
        })
    }

    /// Counts one request against `key` in the current window and decides whether it is allowed.
    ///
    /// Returns the outcome together with the window's count including this request. A
    /// successful call marks the limiter healthy and is recorded in `stats`. On error, nothing
    /// is recorded; reacting (e.g. via `mark_unhealthy`) is left to the caller.
    ///
    /// # Errors
    ///
    /// Returns the Memcached client error. If the whole operation exceeds the configured
    /// timeout, it returns `Error::Io` with kind `TimedOut`. The timeout covers waiting for the
    /// connection, reconnecting and the commands. After any failure the connection is
    /// discarded and the next call reconnects.
    pub async fn check(
        &self,
        key: &str,
    ) -> Result<(RateLimitOutcome, u64), async_memcached::Error> {
        let config = self.config.read().clone();
        let window_secs = config.window_secs.max(1);

        // A clock set before the epoch must not panic the request path; use window 0 instead.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let window_key = format!("{}{}:{}", config.key_prefix, key, now / window_secs);

        let count = tokio::time::timeout(config.timeout, async {
            let mut slot = self.client.lock().await;
            // Take the connection out while it is in use. If a command fails, or the timeout
            // cancels this future mid-request, the connection is dropped here instead of being
            // reused with an unread reply still pending, and the next check reconnects.
            let mut client = match slot.take() {
                Some(client) => client,
                None => async_memcached::Client::new(self.addr.as_str()).await?,
            };
            let count = Self::increment_window(&mut client, &window_key, config.ttl_secs).await?;
            // Only a connection that completed a full request/response cycle is put back.
            *slot = Some(client);
            Ok::<_, async_memcached::Error>(count)
        })
        .await
        .map_err(|_| {
            async_memcached::Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "Memcached operation timed out",
            ))
        })??;

        self.healthy.store(true, Ordering::Relaxed);

        let limit = u64::from(config.max_rps).saturating_mul(window_secs);
        let outcome = if count > limit {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Memcached rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Adds one to `window_key` and returns the new count. If the counter does not exist yet,
    /// it is created with value 1 and a TTL of `ttl_secs`.
    ///
    /// The counter is created with `add`, not `set`. If another instance creates it between
    /// our failed `increment` and the create, `add` is rejected with `NotStored`. In that case
    /// the existing counter is incremented instead of overwriting its hits with `1`.
    async fn increment_window(
        client: &mut async_memcached::Client,
        window_key: &str,
        ttl_secs: u32,
    ) -> Result<u64, async_memcached::Error> {
        match client.increment(window_key, 1).await {
            Err(async_memcached::Error::Protocol(async_memcached::Status::NotFound)) => {}
            other => return other,
        }

        match client
            .add(window_key, &b"1"[..], Some(i64::from(ttl_secs)), None)
            .await
        {
            Ok(_) => Ok(1),
            Err(async_memcached::Error::Protocol(async_memcached::Status::NotStored)) => {
                client.increment(window_key, 1).await
            }
            Err(e) => Err(e),
        }
    }

    /// Applies new prefix, rate, timeout, fallback and TTL settings to subsequent checks.
    ///
    /// The server URL is not re-read; pointing at another server needs a new limiter.
    pub fn update_config(
        &self,
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
        config.ttl_secs = backend_config.ttl_secs;
    }

    /// Returns `false` after `mark_unhealthy`, until the next successful `check`.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Marks the backend unhealthy and counts one Memcached error in `stats`.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_memcached_error();
    }

    /// Returns the configured `fallback_local` flag.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
218
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
/// Placeholder compiled when the `distributed-rate-limit-memcached` feature is disabled.
/// `MemcachedRateLimiter::new` always returns an error in this configuration.
pub struct MemcachedRateLimiter;
222
223#[cfg(not(feature = "distributed-rate-limit-memcached"))]
224impl MemcachedRateLimiter {
225 pub async fn new(
226 _backend_config: &MemcachedBackendConfig,
227 _rate_config: &RateLimitConfig,
228 ) -> Result<Self, String> {
229 Err(
230 "Memcached rate limiting requires the 'distributed-rate-limit-memcached' feature"
231 .to_string(),
232 )
233 }
234}
235
/// Builds a [`MemcachedRateLimiter`], logging the result.
///
/// Returns `None` when construction fails. The failure is logged at error level, and a
/// warning is added when `backend_config.fallback_local` is set.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub async fn create_memcached_rate_limiter(
    backend_config: &MemcachedBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    let limiter = match MemcachedRateLimiter::new(backend_config, rate_config).await {
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Memcached rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
            return None;
        }
        Ok(limiter) => limiter,
    };

    debug!(
        url = %backend_config.url,
        "Memcached rate limiter created successfully"
    );
    Some(limiter)
}
263
264#[cfg(not(feature = "distributed-rate-limit-memcached"))]
265pub async fn create_memcached_rate_limiter(
266 _backend_config: &MemcachedBackendConfig,
267 _rate_config: &RateLimitConfig,
268) -> Option<MemcachedRateLimiter> {
269 warn!(
270 "Memcached rate limiting requested but feature is disabled. Using local rate limiting."
271 );
272 None
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a counter with the same relaxed ordering the stats use.
    fn value(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    #[test]
    fn test_stats_recording() {
        let stats = MemcachedRateLimitStats::default();

        let outcomes = [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ];
        for outcome in outcomes {
            stats.record_check(outcome);
        }

        assert_eq!(value(&stats.total_checks), 3);
        assert_eq!(value(&stats.allowed), 2);
        assert_eq!(value(&stats.limited), 1);
    }

    #[test]
    fn test_stats_memcached_errors() {
        let stats = MemcachedRateLimitStats::default();

        (0..2).for_each(|_| stats.record_memcached_error());
        stats.record_local_fallback();

        assert_eq!(value(&stats.memcached_errors), 2);
        assert_eq!(value(&stats.local_fallbacks), 1);
    }
}