solana_net_utils/token_bucket.rs

//! This module contains [`TokenBucket`], which provides the ability to limit
//! the rate of certain events while allowing bursts through.
//! [`KeyedRateLimiter`] allows rate-limiting multiple keyed items, such
//! as connections.
use {
    cfg_if::cfg_if,
    dashmap::{mapref::entry::Entry, DashMap},
    solana_svm_type_overrides::sync::atomic::{AtomicU64, AtomicUsize, Ordering},
    std::{borrow::Borrow, cmp::Reverse, hash::Hash, time::Instant},
};

/// Enforces a rate limit on the volume of requests per unit time.
///
/// Instances update the number of tokens upon access, and thus do not need to
/// be constantly polled to refill. Uses atomics internally, so it should be
/// relatively cheap to access from many threads.
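///
/// # Example
///
/// A minimal usage sketch (illustrative; marked `ignore` since the exact
/// crate path depends on how this module is consumed):
///
/// ```ignore
/// // Start with 10 tokens, cap at 100, refill at 50 tokens per second.
/// let bucket = TokenBucket::new(10, 100, 50.0);
/// match bucket.consume_tokens(5) {
///     Ok(left) => println!("granted, {left} tokens left"),
///     Err(missing) => println!("rejected, {missing} tokens short"),
/// }
/// ```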
pub struct TokenBucket {
    new_tokens_per_us: f64,
    max_tokens: u64,
    /// Time of bucket creation
    base_time: Instant,
    tokens: AtomicU64,
    /// Time of last update, in us since base_time
    last_update: AtomicU64,
    /// Time left unused by the last token-minting round
    credit_time_us: AtomicU64,
}

#[cfg(feature = "shuttle-test")]
static TIME_US: AtomicU64 = AtomicU64::new(0); // used to override Instant::now()

// If changing this impl, make sure to run benches and ensure they do not panic.
// Much of the testing is impossible outside of real multithreading in release mode.
impl TokenBucket {
    /// Allocates a new TokenBucket.
    pub fn new(initial_tokens: u64, max_tokens: u64, new_tokens_per_second: f64) -> Self {
        assert!(
            new_tokens_per_second > 0.0,
            "Token bucket cannot have a zero influx rate"
        );
        assert!(
            initial_tokens <= max_tokens,
            "Cannot have more initial tokens than max tokens"
        );
        let base_time = Instant::now();
        TokenBucket {
            // recompute into us to avoid FP division on every update
            new_tokens_per_us: new_tokens_per_second / 1e6,
            max_tokens,
            tokens: AtomicU64::new(initial_tokens),
            last_update: AtomicU64::new(0),
            base_time,
            credit_time_us: AtomicU64::new(0),
        }
    }

    /// Returns the current amount of tokens in the bucket.
    /// This may be somewhat inconsistent across threads
    /// due to Relaxed atomics.
    #[inline]
    pub fn current_tokens(&self) -> u64 {
        let now = self.time_us();
        self.update_state(now);
        self.tokens.load(Ordering::Relaxed)
    }

    /// Attempts to consume tokens from the bucket.
    ///
    /// On success, returns Ok(amount of tokens left in the bucket).
    /// On failure, returns Err(amount of tokens missing to fill the request).
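    ///
    /// A hedged sketch of the call pattern (illustrative, not run as a doctest):
    ///
    /// ```ignore
    /// let bucket = TokenBucket::new(100, 100, 1000.0);
    /// match bucket.consume_tokens(50) {
    ///     Ok(remaining) => assert_eq!(remaining, 50),
    ///     Err(missing) => eprintln!("rate limited, short {missing} tokens"),
    /// }
    /// ```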
    #[inline]
    pub fn consume_tokens(&self, request_size: u64) -> Result<u64, u64> {
        let now = self.time_us();
        self.update_state(now);
        match self.tokens.fetch_update(
            Ordering::AcqRel,  // winner publishes the new amount
            Ordering::Acquire, // everyone observes the correct number
            |tokens| {
                if tokens >= request_size {
                    Some(tokens.saturating_sub(request_size))
                } else {
                    None
                }
            },
        ) {
            Ok(prev) => Ok(prev.saturating_sub(request_size)),
            Err(prev) => Err(request_size.saturating_sub(prev)),
        }
    }

    /// Retrieves monotonic time since bucket creation.
    fn time_us(&self) -> u64 {
        cfg_if! {
            if #[cfg(feature="shuttle-test")] {
                TIME_US.load(Ordering::Relaxed)
            } else {
                let now = Instant::now();
                let elapsed = now.saturating_duration_since(self.base_time);
                elapsed.as_micros() as u64
            }
        }
    }

    /// Updates internal state of the bucket by
    /// depositing new tokens (if appropriate).
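    ///
    /// Illustrative arithmetic: at 1000 tokens/sec, `new_tokens_per_us` is
    /// 0.001, so a claimed interval of 2500us is worth 2.5 tokens; 2 whole
    /// tokens are minted and the unused 500us is parked in `credit_time_us`
    /// for a later minting round.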
    fn update_state(&self, now: u64) {
        // fetch last update time
        let last = self.last_update.load(Ordering::SeqCst);

        // If time has not advanced, nothing to do.
        if now <= last {
            return;
        }

        // Try to claim the interval [last, now].
        // If we cannot claim it, someone else will claim [last..some other time] when they
        // touch the bucket.
        // If we can claim the interval [last, now], no other thread can credit tokens for it anymore.
        // If [last, now] is too short to mint any tokens, the spare time is preserved in credit_time_us.
        match self.last_update.compare_exchange(
            last,
            now,
            Ordering::AcqRel,  // winner publishes the new timestamp
            Ordering::Acquire, // loser observes updates
        ) {
            Ok(_) => {
                // This thread won the race and is responsible for minting tokens
                let elapsed = now.saturating_sub(last);

                // Also add leftovers from previous conversion attempts.
                // We do not care who uses credit_time_us, so Relaxed is ok here.
                let elapsed =
                    elapsed.saturating_add(self.credit_time_us.swap(0, Ordering::Relaxed));

                let new_tokens_f64 = elapsed as f64 * self.new_tokens_per_us;

                // amount of whole tokens to be minted
                let new_tokens = new_tokens_f64.floor() as u64;

                let time_to_return = if new_tokens >= 1 {
                    // Credit tokens, saturating at max_tokens
                    let _ = self.tokens.fetch_update(
                        Ordering::AcqRel,  // writer publishes the new amount
                        Ordering::Acquire, // we fetch the correct amount
                        |tokens| Some(tokens.saturating_add(new_tokens).min(self.max_tokens)),
                    );
                    // Fractional remainder of elapsed time (not enough to mint a whole token)
                    // that will be credited to other minters
                    (new_tokens_f64.fract() / self.new_tokens_per_us) as u64
                } else {
                    // No whole tokens minted → return the whole interval
                    elapsed
                };
                // Save unused elapsed time for other threads
                self.credit_time_us
                    .fetch_add(time_to_return, Ordering::Relaxed);
            }
            Err(_) => {
                // Another thread advanced last_update first → nothing we can do now.
            }
        }
    }
}

impl Clone for TokenBucket {
    /// Clones the TokenBucket with approximate state
    /// of the original. While this will never return an object in an
    /// invalid state, using this in a contended environment is not recommended.
    fn clone(&self) -> Self {
        Self {
            new_tokens_per_us: self.new_tokens_per_us,
            max_tokens: self.max_tokens,
            base_time: self.base_time,
            tokens: AtomicU64::new(self.tokens.load(Ordering::Relaxed)),
            last_update: AtomicU64::new(self.last_update.load(Ordering::Relaxed)),
            credit_time_us: AtomicU64::new(self.credit_time_us.load(Ordering::Relaxed)),
        }
    }
}

/// Provides rate limiting for multiple contexts at the same time.
///
/// This can use, e.g., an IP address as a key.
/// Internally this is a [DashMap] of [TokenBucket] instances
/// that are created on demand from a prototype [TokenBucket]
/// whose initial state they copy.
/// Uses lazy-LRU logic under the hood to keep the number of items
/// under control.
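///
/// # Example
///
/// A minimal sketch keyed by IP address (illustrative; marked `ignore` so it
/// is not run as a doctest):
///
/// ```ignore
/// use std::net::{IpAddr, Ipv4Addr};
///
/// let prototype = TokenBucket::new(100, 100, 1000.0);
/// // Track roughly 1024 keys, spread over 16 shards (a power of two).
/// let limiter = KeyedRateLimiter::new(1024, prototype, 16);
/// let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
/// if limiter.consume_tokens(ip, 1).is_ok() {
///     // handle the request
/// }
/// ```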
pub struct KeyedRateLimiter<K>
where
    K: Hash + Eq,
{
    data: DashMap<K, TokenBucket>,
    target_capacity: usize,
    prototype_bucket: TokenBucket,
    countdown_to_shrink: AtomicUsize,
    approx_len: AtomicUsize,
    shrink_interval: usize,
}

impl<K> KeyedRateLimiter<K>
where
    K: Hash + Eq,
{
    /// Creates a new KeyedRateLimiter with a specified target capacity and shard amount for the
    /// underlying DashMap. This uses a LazyLRU-style eviction policy, so actual memory consumption
    /// can reach 2 * target_capacity.
    ///
    /// shard_amount should be greater than 0 and a power of two.
    /// If a shard_amount which is not a power of two is provided, the function will panic.
    #[allow(clippy::arithmetic_side_effects)]
    pub fn new(target_capacity: usize, prototype_bucket: TokenBucket, shard_amount: usize) -> Self {
        let shrink_interval = target_capacity / 4;
        Self {
            data: DashMap::with_capacity_and_shard_amount(target_capacity * 2, shard_amount),
            target_capacity,
            prototype_bucket,
            countdown_to_shrink: AtomicUsize::new(shrink_interval),
            approx_len: AtomicUsize::new(0),
            shrink_interval,
        }
    }

    /// Fetches the amount of tokens available for a key.
    ///
    /// Returns None if no bucket exists for the key provided.
    #[inline]
    pub fn current_tokens(&self, key: impl Borrow<K>) -> Option<u64> {
        let bucket = self.data.get(key.borrow())?;
        Some(bucket.current_tokens())
    }

    /// Consumes request_size tokens from the bucket at the given key.
    ///
    /// On success, returns Ok(amount of tokens left in the bucket).
    /// On failure, returns Err(amount of tokens missing to fill the request).
    /// If no bucket exists at key, a new bucket is allocated and the normal policy is applied to it.
    /// Outdated buckets may be evicted on an LRU basis.
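    ///
    /// A hedged usage sketch (key type and sizes are illustrative):
    ///
    /// ```ignore
    /// let rl = KeyedRateLimiter::new(8, TokenBucket::new(100, 100, 1000.0), 2);
    /// // First touch of a key allocates its bucket from the prototype.
    /// assert!(rl.consume_tokens("client-a", 50).is_ok());
    /// // Charging more than is left reports the shortfall.
    /// assert_eq!(rl.consume_tokens("client-a", 60), Err(10));
    /// ```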
    pub fn consume_tokens(&self, key: K, request_size: u64) -> Result<u64, u64> {
        let (entry_added, res) = {
            let bucket = self.data.entry(key);
            match bucket {
                Entry::Occupied(entry) => (false, entry.get().consume_tokens(request_size)),
                Entry::Vacant(entry) => {
                    // if the key is not in the LRU, we need to allocate a new bucket
                    let bucket = self.prototype_bucket.clone();
                    let res = bucket.consume_tokens(request_size);
                    entry.insert(bucket);
                    (true, res)
                }
            }
        };

        if entry_added {
            if let Ok(count) =
                self.countdown_to_shrink
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                        if v == 0 {
                            // do not decrement past zero; the thread that
                            // observed 1 resets the counter, which prevents
                            // other threads from racing for the shard locks
                            None
                        } else {
                            Some(v.saturating_sub(1))
                        }
                    })
            {
                if count == 1 {
                    // the last "previous" value we will see before the counter reaches zero
                    self.maybe_shrink();
                    self.countdown_to_shrink
                        .store(self.shrink_interval, Ordering::Relaxed);
                }
            } else {
                self.approx_len.fetch_add(1, Ordering::Relaxed);
            }
        }
        res
    }

    /// Returns the approximate number of entries in the data structure.
    /// Should be within ~10% of the true amount.
    #[inline]
    pub fn len_approx(&self) -> usize {
        self.approx_len.load(Ordering::Relaxed)
    }

    // Apply the lazy-LRU eviction policy to each DashMap shard.
    // Arithmetic side effects are allowed below since overflows are not
    // actually possible here.
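    // Illustrative sizing: with target_capacity = 8 and 2 shards,
    // target_shard_size is 4; a shard is left alone until it exceeds
    // 6 entries (3/2 of the target), after which only the 4 entries
    // with the most recent last_update are kept.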
    #[allow(clippy::arithmetic_side_effects)]
    fn maybe_shrink(&self) {
        let mut actual_len = 0;
        let target_shard_size = self.target_capacity / self.data.shards().len();
        let mut entries = Vec::with_capacity(target_shard_size * 2);
        for shardlock in self.data.shards() {
            let mut shard = shardlock.write();

            if shard.len() <= target_shard_size * 3 / 2 {
                actual_len += shard.len();
                continue;
            }
            entries.clear();
            entries.extend(
                shard.drain().map(|(key, value)| {
                    (key, value.get().last_update.load(Ordering::SeqCst), value)
                }),
            );

            entries.select_nth_unstable_by_key(target_shard_size, |(_, last_update, _)| {
                Reverse(*last_update)
            });

            shard.extend(
                entries
                    .drain(..)
                    .take(target_shard_size)
                    .map(|(key, _last_update, value)| (key, value)),
            );
            debug_assert!(shard.len() <= target_shard_size);
            actual_len += shard.len();
        }
        self.approx_len.store(actual_len, Ordering::Relaxed);
    }

    /// Sets the auto-shrink interval. Set to 0 to disable shrinking.
    /// During writes we want to check the length, but not too often,
    /// to reduce the probability of lock contention, so keeping this
    /// large is good for performance (at the cost of memory use).
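    ///
    /// For example (illustrative):
    ///
    /// ```ignore
    /// let mut rl = KeyedRateLimiter::new(8, TokenBucket::new(1, 1, 1.0), 2);
    /// // Disable future auto-shrinking entirely.
    /// rl.set_shrink_interval(0);
    /// ```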
    pub fn set_shrink_interval(&mut self, interval: usize) {
        self.shrink_interval = interval;
    }

    /// Get the auto-shrink interval.
    pub fn shrink_interval(&self) -> usize {
        self.shrink_interval
    }
}

#[cfg(test)]
pub mod test {
    use {
        super::*,
        solana_svm_type_overrides::thread,
        std::{
            net::{IpAddr, Ipv4Addr},
            time::Duration,
        },
    };

    #[test]
    fn test_token_bucket() {
        let tb = TokenBucket::new(100, 100, 1000.0);
        assert_eq!(tb.current_tokens(), 100);
        tb.consume_tokens(50).expect("Bucket is initially full");
        tb.consume_tokens(50)
            .expect("We should still have >50 tokens left");
        tb.consume_tokens(50)
            .expect_err("There should not be enough tokens now");
        thread::sleep(Duration::from_millis(50));
        assert!(
            tb.current_tokens() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            tb.current_tokens() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        tb.consume_tokens(40)
            .expect("Bucket should have enough for another request now");
        thread::sleep(Duration::from_millis(120));
        assert_eq!(tb.current_tokens(), 100, "Bucket should not overfill");
    }

    #[test]
    fn test_keyed_rate_limiter() {
        let prototype_bucket = TokenBucket::new(100, 100, 1000.0);
        let rl = KeyedRateLimiter::new(8, prototype_bucket, 2);
        let ip1 = IpAddr::V4(Ipv4Addr::from_bits(1234));
        let ip2 = IpAddr::V4(Ipv4Addr::from_bits(4321));
        assert_eq!(rl.current_tokens(ip1), None, "Initially no buckets exist");
        rl.consume_tokens(ip1, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip1, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip1, 50)
            .expect_err("There should not be enough tokens now");
        rl.consume_tokens(ip2, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip2, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip2, 50)
            .expect_err("There should not be enough tokens now");
        std::thread::sleep(Duration::from_millis(50));
        assert!(
            rl.current_tokens(ip1).unwrap() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            rl.current_tokens(ip1).unwrap() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        rl.consume_tokens(ip1, 40)
            .expect("Bucket should have enough for another request now");
        thread::sleep(Duration::from_millis(120));
        assert_eq!(
            rl.current_tokens(ip1),
            Some(100),
            "Bucket should not overfill"
        );
        assert_eq!(
            rl.current_tokens(ip2),
            Some(100),
            "Bucket should not overfill"
        );

        rl.consume_tokens(ip2, 100).expect("Bucket should be full");
        // go several times over the target capacity of the rate limiter to make
        // sure the old record is erased no matter which shard it lands in
        for ip in 0..64 {
            let ip = IpAddr::V4(Ipv4Addr::from_bits(ip));
            rl.consume_tokens(ip, 50).unwrap();
        }
        assert_eq!(
            rl.current_tokens(ip1),
            None,
            "Very old record should have been erased"
        );
        rl.consume_tokens(ip2, 100)
            .expect("New bucket should have been made for ip2");
    }

    #[cfg(feature = "shuttle-test")]
    #[test]
    fn shuttle_test_token_bucket_race() {
        use shuttle::sync::atomic::AtomicBool;
        shuttle::check_random(
            || {
                TIME_US.store(0, Ordering::SeqCst);
                let test_duration_us = 2500;
                let run: &AtomicBool = Box::leak(Box::new(AtomicBool::new(true)));
                let tb: &TokenBucket = Box::leak(Box::new(TokenBucket::new(10, 20, 5000.0)));

                // time advancement thread
                let time_advancer = thread::spawn(move || {
                    let mut current_time = 0;
                    while current_time < test_duration_us && run.load(Ordering::SeqCst) {
                        let increment = 100; // microseconds
                        current_time += increment;
                        TIME_US.store(current_time, Ordering::SeqCst);
                        shuttle::thread::yield_now();
                    }
                    run.store(false, Ordering::SeqCst);
                });

                let threads: Vec<_> = (0..2)
                    .map(|_| {
                        thread::spawn(move || {
                            let mut total = 0;
                            while run.load(Ordering::SeqCst) {
                                if tb.consume_tokens(5).is_ok() {
                                    total += 1;
                                }
                                shuttle::thread::yield_now();
                            }
                            total
                        })
                    })
                    .collect();

                time_advancer.join().unwrap();
                let received = threads.into_iter().map(|t| t.join().unwrap()).sum();

                // Initial tokens: 10, refill rate: 5000 tokens/sec (5 tokens/ms).
                // Over the 2.5ms test window at most 12 whole tokens are minted,
                // for up to 22 tokens in total. Each consumption takes 5 tokens,
                // so exactly 4 successful consumptions are expected.
                assert_eq!(4, received);
            },
            100,
        );
    }
}