sentinel_proxy/inference/rate_limit.rs

//! Token-based rate limiting for inference endpoints
//!
//! Provides dual-bucket rate limiting that tracks both:
//! - Tokens per minute (primary limit for LLM APIs)
//! - Requests per minute (secondary limit to prevent abuse)
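//!
//! Typical flow, sketched (the `TokenRateLimit` config comes from
//! `sentinel_config`; see the tests at the bottom of this file):
//!
//! ```ignore
//! let limiter = TokenRateLimiter::new(config);
//! let result = limiter.check(client_key, estimated_tokens);
//! if !result.is_allowed() {
//!     // e.g. reply 429 with result.retry_after_ms()
//! }
//! ```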

use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use tracing::{debug, trace};

use sentinel_config::TokenRateLimit;

/// Result of a rate limit check
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenRateLimitResult {
    /// Request is allowed
    Allowed,
    /// Token limit exceeded
    TokensExceeded {
        /// Milliseconds until retry is allowed
        retry_after_ms: u64,
    },
    /// Request limit exceeded
    RequestsExceeded {
        /// Milliseconds until retry is allowed
        retry_after_ms: u64,
    },
}

impl TokenRateLimitResult {
    /// Returns true if the request is allowed
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allowed)
    }

    /// Get retry-after value in milliseconds (0 if allowed)
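    ///
    /// A caller might round this up into an HTTP `Retry-After` header
    /// (sketch; the header plumbing is outside this module):
    ///
    /// ```ignore
    /// let secs = (result.retry_after_ms() + 999) / 1000;
    /// headers.insert("Retry-After", secs.to_string());
    /// ```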
    pub fn retry_after_ms(&self) -> u64 {
        match self {
            Self::Allowed => 0,
            Self::TokensExceeded { retry_after_ms } => *retry_after_ms,
            Self::RequestsExceeded { retry_after_ms } => *retry_after_ms,
        }
    }
}

/// Token bucket for rate limiting
struct TokenBucket {
    /// Current token count
    tokens: AtomicU64,
    /// Maximum tokens (burst capacity)
    max_tokens: u64,
    /// Tokens added per millisecond
    refill_rate: f64,
    /// Last refill timestamp
    last_refill: std::sync::Mutex<Instant>,
}

impl TokenBucket {
    fn new(tokens_per_minute: u64, burst_tokens: u64) -> Self {
        // Calculate refill rate: tokens per millisecond
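        // (e.g. 60_000 tokens/min -> 1.0 token/ms; 1_000 tokens/min -> ~0.0167 tokens/ms)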
        let refill_rate = tokens_per_minute as f64 / 60_000.0;

        Self {
            tokens: AtomicU64::new(burst_tokens),
            max_tokens: burst_tokens,
            refill_rate,
            last_refill: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Try to consume tokens from the bucket
    ///
    /// On failure, returns the estimated milliseconds until enough
    /// tokens will be available.
    fn try_consume(&self, amount: u64) -> Result<(), u64> {
        // First, refill based on elapsed time
        self.refill();

        // Try to consume
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current < amount {
                // Not enough tokens - calculate wait time. (With a zero
                // refill rate this divides to infinity, which the `as u64`
                // cast saturates to u64::MAX.)
                let needed = amount - current;
                let wait_ms = (needed as f64 / self.refill_rate).ceil() as u64;
                return Err(wait_ms);
            }

            // Try to atomically subtract
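            // (compare_exchange rather than fetch_sub: an unconditional
            // subtract could wrap below zero if another thread consumed
            // the same tokens first)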
            if self
                .tokens
                .compare_exchange(current, current - amount, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return Ok(());
            }
            // CAS failed, retry
        }
    }

    /// Refill tokens based on elapsed time
    fn refill(&self) {
        let mut last = self.last_refill.lock().unwrap();
        let now = Instant::now();
        let elapsed = now.duration_since(*last);

        if elapsed.as_millis() > 0 {
            let refill_amount = (elapsed.as_millis() as f64 * self.refill_rate) as u64;
            if refill_amount > 0 {
                // CAS loop rather than load + store: a plain store could
                // overwrite (and so resurrect) tokens that a concurrent
                // try_consume took between the load and the store.
                let _ = self.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_add(refill_amount).min(self.max_tokens))
                });
                // Only advance the timestamp when whole tokens were added,
                // so sub-token elapsed time keeps accumulating.
                *last = now;
            }
        }
    }

    /// Get current token count
    fn current_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }
}

/// Token-based rate limiter for inference endpoints
///
/// Tracks rate limits per key (typically client IP or API key).
pub struct TokenRateLimiter {
    /// Token buckets per key
    token_buckets: DashMap<String, TokenBucket>,
    /// Request buckets per key (optional)
    request_buckets: Option<DashMap<String, TokenBucket>>,
    /// Configuration
    config: TokenRateLimit,
}

impl TokenRateLimiter {
    /// Create a new token rate limiter
    pub fn new(config: TokenRateLimit) -> Self {
        // Only allocate the request-bucket map when a request limit is configured.
        let request_buckets = config.requests_per_minute.map(|_| DashMap::new());

        Self {
            token_buckets: DashMap::new(),
            request_buckets,
            config,
        }
    }

    /// Check if a request is allowed
    ///
    /// Both token and request limits must pass for the request to be allowed.
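    ///
    /// A caller might branch on the result like this (sketch; `forward`
    /// and `reject_429` are placeholders, not part of this crate):
    ///
    /// ```ignore
    /// match limiter.check(client_key, estimate) {
    ///     TokenRateLimitResult::Allowed => forward(request),
    ///     limited => reject_429(limited.retry_after_ms()),
    /// }
    /// ```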
    pub fn check(&self, key: &str, estimated_tokens: u64) -> TokenRateLimitResult {
        // Check token limit
        let token_bucket = self.token_buckets.entry(key.to_string()).or_insert_with(|| {
            TokenBucket::new(self.config.tokens_per_minute, self.config.burst_tokens)
        });

        if let Err(retry_ms) = token_bucket.try_consume(estimated_tokens) {
            trace!(
                key = key,
                estimated_tokens = estimated_tokens,
                retry_after_ms = retry_ms,
                "Token rate limit exceeded"
            );
            return TokenRateLimitResult::TokensExceeded {
                retry_after_ms: retry_ms,
            };
        }

        // Check request limit if configured
        if let (Some(rpm), Some(request_buckets)) = (self.config.requests_per_minute, self.request_buckets.as_ref()) {
            let request_bucket = request_buckets.entry(key.to_string()).or_insert_with(|| {
                // For request limiting, use burst = rpm / 6 (a 10-second burst window)
                let burst = rpm.max(1) / 6;
                TokenBucket::new(rpm, burst.max(1))
            });

            if let Err(retry_ms) = request_bucket.try_consume(1) {
                // The request is rejected before reaching the backend, so
                // refund the tokens consumed above rather than burning
                // token budget on a request that never ran.
                let _ = token_bucket.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_add(estimated_tokens).min(token_bucket.max_tokens))
                });
                trace!(
                    key = key,
                    retry_after_ms = retry_ms,
                    "Request rate limit exceeded"
                );
                return TokenRateLimitResult::RequestsExceeded {
                    retry_after_ms: retry_ms,
                };
            }
        }

        trace!(
            key = key,
            estimated_tokens = estimated_tokens,
            "Rate limit check passed"
        );
        TokenRateLimitResult::Allowed
    }

    /// Record actual token usage after response
    ///
    /// This allows adjusting the bucket based on actual vs estimated usage.
    /// If actual < estimated, refund the difference.
    /// If actual > estimated, consume the extra (best effort).
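    ///
    /// Sketch of the intended pairing with `check` (`usage.total_tokens`
    /// stands in for whatever the upstream response reports):
    ///
    /// ```ignore
    /// let estimate = 500;
    /// if limiter.check(key, estimate).is_allowed() {
    ///     // ... proxy the request, parse the usage block ...
    ///     limiter.record_actual(key, usage.total_tokens, estimate);
    /// }
    /// ```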
    pub fn record_actual(&self, key: &str, actual_tokens: u64, estimated_tokens: u64) {
        if let Some(bucket) = self.token_buckets.get(key) {
            if actual_tokens < estimated_tokens {
                // Refund over-estimation. CAS loop so a concurrent consume
                // is not overwritten; clamped to burst capacity.
                let refund = estimated_tokens - actual_tokens;
                let _ = bucket.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_add(refund).min(bucket.max_tokens))
                });

                debug!(
                    key = key,
                    actual = actual_tokens,
                    estimated = estimated_tokens,
                    refund = refund,
                    "Refunded over-estimated tokens"
                );
            } else if actual_tokens > estimated_tokens {
                // Under-estimation - consume the extra without blocking.
                // saturating_sub inside a CAS loop avoids the u64 wrap a
                // bare fetch_sub could hit if another thread drained the
                // bucket between a load and the subtract.
                let extra = actual_tokens - estimated_tokens;
                let prev = bucket
                    .tokens
                    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                        Some(current.saturating_sub(extra))
                    })
                    .unwrap_or(0);

                debug!(
                    key = key,
                    actual = actual_tokens,
                    estimated = estimated_tokens,
                    consumed_extra = extra.min(prev),
                    "Consumed under-estimated tokens"
                );
            }
        }
    }

    /// Get current token count for a key
    pub fn current_tokens(&self, key: &str) -> Option<u64> {
        self.token_buckets.get(key).map(|b| b.current_tokens())
    }

    /// Get stats for metrics
    pub fn stats(&self) -> TokenRateLimiterStats {
        TokenRateLimiterStats {
            active_keys: self.token_buckets.len(),
            tokens_per_minute: self.config.tokens_per_minute,
            requests_per_minute: self.config.requests_per_minute,
        }
    }
}

/// Stats for the token rate limiter
#[derive(Debug, Clone)]
pub struct TokenRateLimiterStats {
    /// Number of active rate limit keys
    pub active_keys: usize,
    /// Configured tokens per minute
    pub tokens_per_minute: u64,
    /// Configured requests per minute (if any)
    pub requests_per_minute: Option<u64>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_config::TokenEstimation;

    fn test_config() -> TokenRateLimit {
        TokenRateLimit {
            tokens_per_minute: 1000,
            requests_per_minute: Some(10),
            burst_tokens: 200,
            estimation_method: TokenEstimation::Chars,
        }
    }

    #[test]
    fn test_basic_rate_limiting() {
        let limiter = TokenRateLimiter::new(test_config());

        // First request should succeed
        let result = limiter.check("test-key", 50);
        assert!(result.is_allowed());

        // Should still have tokens
        let current = limiter.current_tokens("test-key").unwrap();
        assert!(current > 0);
    }

    #[test]
    fn test_token_exhaustion() {
        // Token-only config: with the default test_config, the request
        // bucket (rpm = 10, burst of 1) would reject checks before the
        // token bucket emptied.
        let config = TokenRateLimit {
            requests_per_minute: None,
            ..test_config()
        };
        let limiter = TokenRateLimiter::new(config);

        // Exhaust the 200 burst tokens
        for _ in 0..4 {
            let _ = limiter.check("test-key", 50);
        }

        // This should exceed the 200 burst tokens
        let result = limiter.check("test-key", 50);
        assert!(!result.is_allowed());
        assert!(matches!(result, TokenRateLimitResult::TokensExceeded { .. }));
    }

    #[test]
    fn test_actual_token_refund() {
        let limiter = TokenRateLimiter::new(test_config());

        // Consume with high estimate
        let _ = limiter.check("test-key", 100);
        let before = limiter.current_tokens("test-key").unwrap();

        // Record actual as lower
        limiter.record_actual("test-key", 50, 100);
        let after = limiter.current_tokens("test-key").unwrap();

        // Should have refunded the 50-token difference
        assert!(after > before);
    }
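
    // Sketch of the request-limit path: with rpm = 10 the request bucket's
    // burst is 10 / 6 = 1, so a second immediate request on the same key is
    // request-limited even though token budget remains.
    #[test]
    fn test_request_limit_exceeded() {
        let limiter = TokenRateLimiter::new(test_config());

        assert!(limiter.check("test-key", 10).is_allowed());

        let result = limiter.check("test-key", 10);
        assert!(matches!(
            result,
            TokenRateLimitResult::RequestsExceeded { .. }
        ));
    }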
}