1use {
6 cfg_if::cfg_if,
7 dashmap::{mapref::entry::Entry, DashMap},
8 solana_svm_type_overrides::sync::atomic::{AtomicU64, AtomicUsize, Ordering},
9 std::{borrow::Borrow, cmp::Reverse, hash::Hash, time::Instant},
10};
11
/// A lock-free token bucket for rate limiting.
///
/// Tokens are replenished lazily: each read/consume first calls
/// `update_state`, which credits tokens based on microseconds elapsed
/// since the last refill (see `update_state` / `time_us`).
pub struct TokenBucket {
    // Refill rate in tokens per microsecond (converted in `new` from
    // the tokens-per-second argument).
    new_tokens_per_us: f64,
    // Hard cap on stored tokens; refills never push the count above this.
    max_tokens: u64,
    // Reference instant; `time_us` reports microseconds elapsed since it.
    base_time: Instant,
    // Tokens currently available for consumption.
    tokens: AtomicU64,
    // Timestamp (us since `base_time`) of the last successful refill CAS.
    last_update: AtomicU64,
    // Elapsed microseconds banked because they were too small to mint a
    // whole token; credited back on the next refill.
    credit_time_us: AtomicU64,
}
28
// Mock clock used in place of `Instant` when running under shuttle
// concurrency tests; the test harness advances it manually (see
// `shuttle_test_token_bucket_race`).
#[cfg(feature = "shuttle-test")]
static TIME_US: AtomicU64 = AtomicU64::new(0);

impl TokenBucket {
35 pub fn new(initial_tokens: u64, max_tokens: u64, new_tokens_per_second: f64) -> Self {
37 assert!(
38 new_tokens_per_second > 0.0,
39 "Token bucket can not have zero influx rate"
40 );
41 assert!(
42 initial_tokens <= max_tokens,
43 "Can not have more initial tokens than max tokens"
44 );
45 let base_time = Instant::now();
46 TokenBucket {
47 new_tokens_per_us: new_tokens_per_second / 1e6,
49 max_tokens,
50 tokens: AtomicU64::new(initial_tokens),
51 last_update: AtomicU64::new(0),
52 base_time,
53 credit_time_us: AtomicU64::new(0),
54 }
55 }
56
57 #[inline]
61 pub fn current_tokens(&self) -> u64 {
62 let now = self.time_us();
63 self.update_state(now);
64 self.tokens.load(Ordering::Relaxed)
65 }
66
67 #[inline]
72 pub fn consume_tokens(&self, request_size: u64) -> Result<u64, u64> {
73 let now = self.time_us();
74 self.update_state(now);
75 match self.tokens.fetch_update(
76 Ordering::AcqRel, Ordering::Acquire, |tokens| {
79 if tokens >= request_size {
80 Some(tokens.saturating_sub(request_size))
81 } else {
82 None
83 }
84 },
85 ) {
86 Ok(prev) => Ok(prev.saturating_sub(request_size)),
87 Err(prev) => Err(request_size.saturating_sub(prev)),
88 }
89 }
90
    /// Returns the bucket-relative clock in microseconds.
    ///
    /// Normally this is wall-clock time elapsed since `base_time`; under the
    /// `shuttle-test` feature it reads the mock `TIME_US` clock instead so
    /// tests can advance time deterministically.
    fn time_us(&self) -> u64 {
        cfg_if! {
            if #[cfg(feature="shuttle-test")] {
                TIME_US.load(Ordering::Relaxed)
            } else {
                let now = Instant::now();
                // saturating_duration_since guards against non-monotonic reads.
                let elapsed = now.saturating_duration_since(self.base_time);
                elapsed.as_micros() as u64
            }
        }
    }
103
    /// Lazily refills the bucket for the time elapsed since the last refill.
    ///
    /// At most one thread wins the CAS on `last_update` for a given interval
    /// and performs the credit; losing threads return immediately — the
    /// winner's refill covers their elapsed time.
    fn update_state(&self, now: u64) {
        let last = self.last_update.load(Ordering::SeqCst);

        // Clock has not advanced past the last refill (or another thread
        // already refilled up to a later timestamp) — nothing to credit.
        if now <= last {
            return;
        }

        // Claim the interval [last, now]; exactly one thread can succeed
        // for any given `last` value.
        match self.last_update.compare_exchange(
            last,
            now,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                let elapsed = now.saturating_sub(last);

                // Add back fractional time banked by earlier refills that was
                // too small to mint a whole token.
                let elapsed =
                    elapsed.saturating_add(self.credit_time_us.swap(0, Ordering::Relaxed));

                let new_tokens_f64 = elapsed as f64 * self.new_tokens_per_us;

                // Only whole tokens are deposited into the bucket.
                let new_tokens = new_tokens_f64.floor() as u64;

                let time_to_return = if new_tokens >= 1 {
                    // Deposit, clamping at capacity.
                    let _ = self.tokens.fetch_update(
                        Ordering::AcqRel,
                        Ordering::Acquire,
                        |tokens| Some(tokens.saturating_add(new_tokens).min(self.max_tokens)),
                    );
                    // Convert the leftover fractional token back into
                    // microseconds to be credited on a future refill.
                    (new_tokens_f64.fract() / self.new_tokens_per_us) as u64
                } else {
                    // Not enough elapsed time for even one token: bank it all.
                    elapsed
                };
                self.credit_time_us
                    .fetch_add(time_to_return, Ordering::Relaxed);
            }
            Err(_) => {
                // Lost the race; the winning thread accounts for this interval.
            }
        }
    }
163}
164
165impl Clone for TokenBucket {
166 fn clone(&self) -> Self {
170 Self {
171 new_tokens_per_us: self.new_tokens_per_us,
172 max_tokens: self.max_tokens,
173 base_time: self.base_time,
174 tokens: AtomicU64::new(self.tokens.load(Ordering::Relaxed)),
175 last_update: AtomicU64::new(self.last_update.load(Ordering::Relaxed)),
176 credit_time_us: AtomicU64::new(self.credit_time_us.load(Ordering::Relaxed)),
177 }
178 }
179}
180
/// A collection of token buckets indexed by key `K` (e.g. a peer address),
/// with periodic eviction of the least-recently-updated buckets to keep the
/// map near `target_capacity` entries.
pub struct KeyedRateLimiter<K>
where
    K: Hash + Eq,
{
    // Per-key buckets, sharded for concurrent access.
    data: DashMap<K, TokenBucket>,
    // Soft cap on tracked keys; enforced per shard by `maybe_shrink`.
    target_capacity: usize,
    // Every newly seen key gets a clone of this bucket.
    prototype_bucket: TokenBucket,
    // Insertions remaining until the next shrink pass is triggered.
    countdown_to_shrink: AtomicUsize,
    // Cached, approximate entry count; re-synced by `maybe_shrink`.
    approx_len: AtomicUsize,
    // Number of insertions between shrink passes.
    shrink_interval: usize,
}
200
201impl<K> KeyedRateLimiter<K>
202where
203 K: Hash + Eq,
204{
205 #[allow(clippy::arithmetic_side_effects)]
212 pub fn new(target_capacity: usize, prototype_bucket: TokenBucket, shard_amount: usize) -> Self {
213 let shrink_interval = target_capacity / 4;
214 Self {
215 data: DashMap::with_capacity_and_shard_amount(target_capacity * 2, shard_amount),
216 target_capacity,
217 prototype_bucket,
218 countdown_to_shrink: AtomicUsize::new(shrink_interval),
219 approx_len: AtomicUsize::new(0),
220 shrink_interval,
221 }
222 }
223
224 #[inline]
228 pub fn current_tokens(&self, key: impl Borrow<K>) -> Option<u64> {
229 let bucket = self.data.get(key.borrow())?;
230 Some(bucket.current_tokens())
231 }
232
    /// Consumes `request_size` tokens from the bucket belonging to `key`,
    /// creating a bucket (cloned from the prototype) on first sight of the
    /// key. Returns `Ok(tokens_remaining)` or `Err(shortfall)`, as
    /// `TokenBucket::consume_tokens` does.
    pub fn consume_tokens(&self, key: K, request_size: u64) -> Result<u64, u64> {
        let (entry_added, res) = {
            // Entry holds the shard lock; keep this scope tight.
            let bucket = self.data.entry(key);
            match bucket {
                Entry::Occupied(entry) => (false, entry.get().consume_tokens(request_size)),
                Entry::Vacant(entry) => {
                    let bucket = self.prototype_bucket.clone();
                    // Consume before inserting so we don't need to look the
                    // entry up again afterwards.
                    let res = bucket.consume_tokens(request_size);
                    entry.insert(bucket);
                    (true, res)
                }
            }
        };

        if entry_added {
            // Each insertion ticks the shrink countdown down by one; the
            // thread that moves it from 1 to 0 runs the shrink pass.
            if let Ok(count) =
                self.countdown_to_shrink
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                        // Refuse to decrement past zero: a zero countdown
                        // means a shrink is currently in flight elsewhere.
                        if v == 0 {
                            None
                        } else {
                            Some(v.saturating_sub(1))
                        }
                    })
            {
                if count == 1 {
                    // We took the countdown to 0: shrink, then re-arm it.
                    self.maybe_shrink();
                    self.countdown_to_shrink
                        .store(self.shrink_interval, Ordering::Relaxed);
                }
            } else {
                // Countdown was already 0 (shrink in progress): record the
                // insert so the length estimate stays roughly current.
                // NOTE(review): approx_len is only bumped on this path, not
                // on ordinary inserts; it is re-synced wholesale by
                // maybe_shrink() — confirm this undercount is intended.
                self.approx_len.fetch_add(1, Ordering::Relaxed);
            }
        }
        res
    }
279
    /// Returns the approximate number of tracked keys. The value is exact
    /// right after `maybe_shrink` recomputes it, and may drift between
    /// shrink passes.
    #[inline]
    pub fn len_approx(&self) -> usize {
        self.approx_len.load(Ordering::Relaxed)
    }
286
    /// Walks every shard and, for any shard that has grown past 150% of its
    /// share of `target_capacity`, keeps only the `target_shard_size` most
    /// recently updated buckets and drops the rest. Also re-syncs
    /// `approx_len` from the actual shard sizes.
    #[allow(clippy::arithmetic_side_effects)]
    fn maybe_shrink(&self) {
        let mut actual_len = 0;
        let target_shard_size = self.target_capacity / self.data.shards().len();
        // Scratch buffer reused across shards to avoid per-shard allocation.
        let mut entries = Vec::with_capacity(target_shard_size * 2);
        for shardlock in self.data.shards() {
            // Write-lock the shard for the duration of its cleanup.
            let mut shard = shardlock.write();

            // Shards within 150% of their share are left untouched.
            if shard.len() <= target_shard_size * 3 / 2 {
                actual_len += shard.len();
                continue;
            }
            entries.clear();
            // Drain the shard, tagging each bucket with its last refill time
            // so we can rank entries by recency.
            entries.extend(
                shard.drain().map(|(key, value)| {
                    (key, value.get().last_update.load(Ordering::SeqCst), value)
                }),
            );

            // Partition so the `target_shard_size` most recently updated
            // entries come first (Reverse key => descending by last_update).
            // O(n) on average, no full sort needed.
            entries.select_nth_unstable_by_key(target_shard_size, |(_, last_update, _)| {
                Reverse(*last_update)
            });

            // Re-insert only the most recent entries; the remainder are
            // dropped when `entries` is cleared/drained.
            shard.extend(
                entries
                    .drain(..)
                    .take(target_shard_size)
                    .map(|(key, _last_update, value)| (key, value)),
            );
            debug_assert!(shard.len() <= target_shard_size);
            actual_len += shard.len();
        }
        self.approx_len.store(actual_len, Ordering::Relaxed);
    }
324
    /// Sets the number of insertions between shrink passes. Takes `&mut
    /// self`, so it cannot race with concurrent consumers.
    pub fn set_shrink_interval(&mut self, interval: usize) {
        self.shrink_interval = interval;
    }
332
    /// Returns the configured number of insertions between shrink passes.
    pub fn shrink_interval(&self) -> usize {
        self.shrink_interval
    }
337}
338
#[cfg(test)]
pub mod test {
    use {
        super::*,
        solana_svm_type_overrides::thread,
        std::{
            net::{IpAddr, Ipv4Addr},
            time::Duration,
        },
    };

    #[test]
    fn test_token_bucket() {
        // 100 initial tokens, cap 100, refill 1000 tokens/s ~= 1 token/ms.
        let tb = TokenBucket::new(100, 100, 1000.0);
        assert_eq!(tb.current_tokens(), 100);
        tb.consume_tokens(50).expect("Bucket is initially full");
        tb.consume_tokens(50)
            .expect("We should still have >50 tokens left");
        tb.consume_tokens(50)
            .expect_err("There should not be enough tokens now");
        thread::sleep(Duration::from_millis(50));
        // ~50 tokens should have been refilled; bounds are loose to
        // tolerate scheduler jitter around the sleep above.
        assert!(
            tb.current_tokens() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            tb.current_tokens() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        tb.consume_tokens(40)
            .expect("Bucket should have enough for another request now");
        // 120ms is more than enough to refill to the 100-token cap.
        thread::sleep(Duration::from_millis(120));
        assert_eq!(tb.current_tokens(), 100, "Bucket should not overfill");
    }
    #[test]
    fn test_keyed_rate_limiter() {
        let prototype_bucket = TokenBucket::new(100, 100, 1000.0);
        // Target capacity 8 across 2 shards => 4 entries per shard,
        // shrink every 8/4 = 2 insertions.
        let rl = KeyedRateLimiter::new(8, prototype_bucket, 2);
        let ip1 = IpAddr::V4(Ipv4Addr::from_bits(1234));
        let ip2 = IpAddr::V4(Ipv4Addr::from_bits(4321));
        assert_eq!(rl.current_tokens(ip1), None, "Initially no buckets exist");
        rl.consume_tokens(ip1, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip1, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip1, 50)
            .expect_err("There should not be enough tokens now");
        rl.consume_tokens(ip2, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip2, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip2, 50)
            .expect_err("There should not be enough tokens now");
        std::thread::sleep(Duration::from_millis(50));
        assert!(
            rl.current_tokens(ip1).unwrap() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            rl.current_tokens(ip1).unwrap() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        rl.consume_tokens(ip1, 40)
            .expect("Bucket should have enough for another request now");
        thread::sleep(Duration::from_millis(120));
        assert_eq!(
            rl.current_tokens(ip1),
            Some(100),
            "Bucket should not overfill"
        );
        assert_eq!(
            rl.current_tokens(ip2),
            Some(100),
            "Bucket should not overfill"
        );

        rl.consume_tokens(ip2, 100).expect("Bucket should be full");
        // Flood the limiter with 64 fresh keys so shrink passes evict
        // stale entries; ip1's bucket (untouched) should get dropped.
        for ip in 0..64 {
            let ip = IpAddr::V4(Ipv4Addr::from_bits(ip));
            rl.consume_tokens(ip, 50).unwrap();
        }
        assert_eq!(
            rl.current_tokens(ip1),
            None,
            "Very old record should have been erased"
        );
        // ip2 may also have been evicted; a fresh full bucket is created.
        rl.consume_tokens(ip2, 100)
            .expect("New bucket should have been made for ip2");
    }

    #[cfg(feature = "shuttle-test")]
    #[test]
    fn shuttle_test_token_bucket_race() {
        use shuttle::sync::atomic::AtomicBool;
        shuttle::check_random(
            || {
                // Reset the mock clock for each shuttle iteration.
                TIME_US.store(0, Ordering::SeqCst);
                let test_duration_us = 2500;
                // Leak so the 'static refs can be shared with spawned threads.
                let run: &AtomicBool = Box::leak(Box::new(AtomicBool::new(true)));
                let tb: &TokenBucket = Box::leak(Box::new(TokenBucket::new(10, 20, 5000.0)));

                // Drives the mock clock forward in 100us steps.
                let time_advancer = thread::spawn(move || {
                    let mut current_time = 0;
                    while current_time < test_duration_us && run.load(Ordering::SeqCst) {
                        let increment = 100;
                        current_time += increment;
                        TIME_US.store(current_time, Ordering::SeqCst);
                        shuttle::thread::yield_now();
                    }
                    run.store(false, Ordering::SeqCst);
                });

                // Two consumers race to pull 5 tokens at a time.
                let threads: Vec<_> = (0..2)
                    .map(|_| {
                        thread::spawn(move || {
                            let mut total = 0;
                            while run.load(Ordering::SeqCst) {
                                if tb.consume_tokens(5).is_ok() {
                                    total += 1;
                                }
                                shuttle::thread::yield_now();
                            }
                            total
                        })
                    })
                    .collect();

                time_advancer.join().unwrap();
                let received = threads.into_iter().map(|t| t.join().unwrap()).sum();

                // 10 initial tokens + 2500us * 5 tokens/ms = 22.5 tokens
                // total => exactly 4 successful 5-token grabs regardless of
                // interleaving.
                assert_eq!(4, received);
            },
            100,
        );
    }
}