sentinel_proxy/inference/rate_limit.rs

use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use tracing::{debug, trace};

use sentinel_config::TokenRateLimit;

/// Outcome of a rate-limit check for a single request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenRateLimitResult {
    /// The request fits both the token and request budgets.
    Allowed,
    /// The per-key token budget is exhausted.
    TokensExceeded {
        retry_after_ms: u64,
    },
    /// The per-key request budget is exhausted.
    RequestsExceeded {
        retry_after_ms: u64,
    },
}

impl TokenRateLimitResult {
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allowed)
    }

    /// Suggested client back-off in milliseconds (0 when allowed).
    pub fn retry_after_ms(&self) -> u64 {
        match self {
            Self::Allowed => 0,
            Self::TokensExceeded { retry_after_ms } => *retry_after_ms,
            Self::RequestsExceeded { retry_after_ms } => *retry_after_ms,
        }
    }
}

/// A token bucket with an atomic token count; only the refill timestamp is
/// guarded by a mutex.
struct TokenBucket {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Capacity of the bucket (the configured burst size).
    max_tokens: u64,
    /// Refill rate in tokens per millisecond.
    refill_rate: f64,
    /// When the bucket was last credited.
    last_refill: std::sync::Mutex<Instant>,
}

impl TokenBucket {
    fn new(tokens_per_minute: u64, burst_tokens: u64) -> Self {
        // Convert the per-minute budget into tokens per millisecond.
        let refill_rate = tokens_per_minute as f64 / 60_000.0;

        Self {
            tokens: AtomicU64::new(burst_tokens),
            max_tokens: burst_tokens,
            refill_rate,
            last_refill: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Try to take `amount` tokens. On failure, returns the estimated number
    /// of milliseconds until enough tokens will have accrued.
    fn try_consume(&self, amount: u64) -> Result<(), u64> {
        self.refill();

        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current < amount {
                let needed = amount - current;
                let wait_ms = (needed as f64 / self.refill_rate).ceil() as u64;
                return Err(wait_ms);
            }

            // CAS loop: retry if another thread touched the counter between
            // the load and the exchange.
            if self
                .tokens
                .compare_exchange(current, current - amount, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return Ok(());
            }
        }
    }

    /// Credit tokens accrued since the last refill, capped at `max_tokens`.
    fn refill(&self) {
        let mut last = self.last_refill.lock().unwrap();
        let now = Instant::now();
        let elapsed = now.duration_since(*last);

        if elapsed.as_millis() > 0 {
            let refill_amount = (elapsed.as_millis() as f64 * self.refill_rate) as u64;
            if refill_amount > 0 {
                // fetch_update keeps the credit atomic with respect to a
                // concurrent try_consume; a separate load/store pair could
                // silently overwrite another thread's decrement.
                let _ = self.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some((current + refill_amount).min(self.max_tokens))
                });
                // Only advance the timestamp once at least one whole token has
                // accrued, so fractional progress is not discarded.
                *last = now;
            }
        }
    }

    /// Refill, then report the current balance.
    fn current_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }
}

/// Per-key token and request rate limiter for inference traffic.
pub struct TokenRateLimiter {
    /// One token bucket per rate-limit key.
    token_buckets: DashMap<String, TokenBucket>,
    /// Per-key request buckets, present only when `requests_per_minute`
    /// is configured.
    request_buckets: Option<DashMap<String, TokenBucket>>,
    config: TokenRateLimit,
}

impl TokenRateLimiter {
    pub fn new(config: TokenRateLimit) -> Self {
        let request_buckets = config.requests_per_minute.map(|_| DashMap::new());

        Self {
            token_buckets: DashMap::new(),
            request_buckets,
            config,
        }
    }

    /// Check whether a request estimated at `estimated_tokens` may proceed for
    /// `key`, consuming from the key's budgets if so.
    pub fn check(&self, key: &str, estimated_tokens: u64) -> TokenRateLimitResult {
        let token_bucket = self
            .token_buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket::new(self.config.tokens_per_minute, self.config.burst_tokens));

        if let Err(retry_ms) = token_bucket.try_consume(estimated_tokens) {
            trace!(
                key = key,
                estimated_tokens = estimated_tokens,
                retry_after_ms = retry_ms,
                "Token rate limit exceeded"
            );
            return TokenRateLimitResult::TokensExceeded {
                retry_after_ms: retry_ms,
            };
        }

        if let (Some(rpm), Some(request_buckets)) =
            (self.config.requests_per_minute, self.request_buckets.as_ref())
        {
            let request_bucket = request_buckets.entry(key.to_string()).or_insert_with(|| {
                // Allow a burst of roughly ten seconds' worth of requests,
                // but never fewer than one.
                let burst = rpm.max(1) / 6;
                TokenBucket::new(rpm, burst.max(1))
            });

            // Note: the token debit above is not rolled back when the request
            // limit rejects the call.
            if let Err(retry_ms) = request_bucket.try_consume(1) {
                trace!(
                    key = key,
                    retry_after_ms = retry_ms,
                    "Request rate limit exceeded"
                );
                return TokenRateLimitResult::RequestsExceeded {
                    retry_after_ms: retry_ms,
                };
            }
        }

        trace!(
            key = key,
            estimated_tokens = estimated_tokens,
            "Rate limit check passed"
        );
        TokenRateLimitResult::Allowed
    }

    /// Reconcile a key's bucket once the actual token usage of a request is
    /// known: refund over-estimates, debit under-estimates.
    pub fn record_actual(&self, key: &str, actual_tokens: u64, estimated_tokens: u64) {
        if let Some(bucket) = self.token_buckets.get(key) {
            if actual_tokens < estimated_tokens {
                let refund = estimated_tokens - actual_tokens;
                // fetch_update keeps the refund atomic with respect to
                // concurrent consumers.
                let _ = bucket.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some((current + refund).min(bucket.max_tokens))
                });

                debug!(
                    key = key,
                    actual = actual_tokens,
                    estimated = estimated_tokens,
                    refund = refund,
                    "Refunded over-estimated tokens"
                );
            } else if actual_tokens > estimated_tokens {
                let extra = actual_tokens - estimated_tokens;
                // Debit the shortfall, saturating at zero rather than wrapping.
                let previous = bucket
                    .tokens
                    .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                        Some(current.saturating_sub(extra))
                    })
                    .unwrap_or(0);
                let to_consume = extra.min(previous);

                debug!(
                    key = key,
                    actual = actual_tokens,
                    estimated = estimated_tokens,
                    consumed_extra = to_consume,
                    "Consumed under-estimated tokens"
                );
            }
        }
    }

    /// Current token balance for `key`, if a bucket exists for it.
    pub fn current_tokens(&self, key: &str) -> Option<u64> {
        self.token_buckets.get(key).map(|b| b.current_tokens())
    }

    /// Snapshot of limiter activity and configuration, e.g. for metrics.
    pub fn stats(&self) -> TokenRateLimiterStats {
        TokenRateLimiterStats {
            active_keys: self.token_buckets.len(),
            tokens_per_minute: self.config.tokens_per_minute,
            requests_per_minute: self.config.requests_per_minute,
        }
    }
}

#[derive(Debug, Clone)]
pub struct TokenRateLimiterStats {
    pub active_keys: usize,
    pub tokens_per_minute: u64,
    pub requests_per_minute: Option<u64>,
}

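// Added usage sketch, not part of the original module: one way a request
// handler could drive the limiter around an upstream call. The function name
// and the `forward` closure are hypothetical; only `check`, `retry_after_ms`,
// and `record_actual` come from the types above.
#[allow(dead_code)]
fn rate_limited_call<T>(
    limiter: &TokenRateLimiter,
    key: &str,
    estimated_tokens: u64,
    forward: impl FnOnce() -> (T, u64),
) -> Result<T, u64> {
    match limiter.check(key, estimated_tokens) {
        TokenRateLimitResult::Allowed => {
            // Forward the request; the closure returns the response together
            // with the actual token count reported by the upstream.
            let (response, actual_tokens) = forward();
            limiter.record_actual(key, actual_tokens, estimated_tokens);
            Ok(response)
        }
        // Surface the suggested back-off so the caller can emit Retry-After.
        rejected => Err(rejected.retry_after_ms()),
    }
}
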
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_config::TokenEstimation;

    fn test_config() -> TokenRateLimit {
        TokenRateLimit {
            tokens_per_minute: 1000,
            requests_per_minute: Some(10),
            burst_tokens: 200,
            estimation_method: TokenEstimation::Chars,
        }
    }

    #[test]
    fn test_basic_rate_limiting() {
        let limiter = TokenRateLimiter::new(test_config());

        let result = limiter.check("test-key", 50);
        assert!(result.is_allowed());

        let current = limiter.current_tokens("test-key").unwrap();
        assert!(current > 0);
    }

    #[test]
    fn test_token_exhaustion() {
        let limiter = TokenRateLimiter::new(test_config());

        // Drain the 200-token burst with four 50-token checks.
        for _ in 0..4 {
            let _ = limiter.check("test-key", 50);
        }

        let result = limiter.check("test-key", 50);
        assert!(!result.is_allowed());
        assert!(matches!(result, TokenRateLimitResult::TokensExceeded { .. }));
    }
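
    // Added sketch, not in the original test suite: exercises the per-request
    // limit path. With test_config(), requests_per_minute is Some(10), so the
    // request bucket's burst is 10 / 6 = 1 and a second back-to-back request
    // for the same key should be rejected even though plenty of tokens remain.
    #[test]
    fn test_request_exhaustion_sketch() {
        let limiter = TokenRateLimiter::new(test_config());

        assert!(limiter.check("rpm-key", 10).is_allowed());

        let result = limiter.check("rpm-key", 10);
        assert!(matches!(result, TokenRateLimitResult::RequestsExceeded { .. }));
        assert!(result.retry_after_ms() > 0);
    }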

    #[test]
    fn test_actual_token_refund() {
        let limiter = TokenRateLimiter::new(test_config());

        let _ = limiter.check("test-key", 100);
        let before = limiter.current_tokens("test-key").unwrap();

        // Only 50 of the estimated 100 tokens were actually used, so part of
        // the debit is returned to the bucket.
        limiter.record_actual("test-key", 50, 100);
        let after = limiter.current_tokens("test-key").unwrap();

        assert!(after > before);
    }
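
    // Added sketch, not in the original test suite: the opposite direction of
    // the refund test above. When the upstream used more tokens than
    // estimated, record_actual debits the difference from the remaining budget.
    #[test]
    fn test_actual_token_extra_consumption_sketch() {
        let limiter = TokenRateLimiter::new(test_config());

        let _ = limiter.check("test-key", 50);
        let before = limiter.current_tokens("test-key").unwrap();

        // Actual usage (100) exceeded the estimate (50), so the extra 50
        // tokens are consumed after the fact.
        limiter.record_actual("test-key", 100, 50);
        let after = limiter.current_tokens("test-key").unwrap();

        assert!(after < before);
    }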
}