tokio_splice2/rate.rs

//! Token bucket rate limiter for limiting the byte transfer rate.

#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_truncation)]

use std::cmp::min;
use std::num::{NonZeroU16, NonZeroU64, NonZeroUsize};
use std::sync::atomic::{AtomicU16, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};

use crossbeam_utils::CachePadded;
use tokio::time::Instant;

const TOKIO_TIMER_MIN_DUR: Duration = Duration::from_millis(1);

pub(crate) const RATE_LIMITER_ENABLED: bool = true;
pub(crate) const RATE_LIMITER_DISABLED: bool = false;
#[derive(Debug, Clone)]
/// Byte transfer rate limit, in `B/s`.
///
/// The limit can be shared by multiple connections, see
/// [`RateLimit::new_shared_by`] or [`RateLimit::set_share_by`].
///
/// # Notes
///
/// - (WIP) Not perfectly accurate, as Tokio's timer resolution is limited to 1ms.
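///
/// # Example
///
/// A minimal sketch of splitting one limit across two connections (the import
/// path below is assumed for illustration; use wherever the crate actually
/// exposes [`RateLimit`]):
///
/// ```ignore
/// use std::num::NonZeroU64;
/// use tokio_splice2::RateLimit; // assumed re-export path
///
/// // 1 MiB/s in total, shared by every clone created via `clone_shared`.
/// let limit = RateLimit::new(NonZeroU64::new(1024 * 1024).unwrap());
/// let second = limit.clone_shared(); // now shared by 2 connections
/// assert!(second.current().is_some());
/// ```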
pub struct RateLimit {
    total: Arc<AtomicU64>,
    shared_by: Arc<AtomicU16>,
}

impl RateLimit {
    const DISABLED: u64 = 0;

    /// Create a new [`RateLimit`].
    #[must_use]
    pub fn new(limit: NonZeroU64) -> Self {
        Self {
            total: Arc::new(AtomicU64::new(limit.get())),
            shared_by: Arc::new(AtomicU16::new(1)),
        }
    }

    /// Create a new [`RateLimit`] that is disabled (total = 0).
    #[must_use]
    pub fn new_disabled() -> Self {
        Self {
            total: Arc::new(AtomicU64::new(Self::DISABLED)),
            shared_by: Arc::new(AtomicU16::new(1)),
        }
    }

    /// Create a new [`RateLimit`] that is shared by `N` instances.
    ///
    /// # Panics
    ///
    /// If `N` is 0.
    #[must_use]
    pub fn new_shared_by<const N: u16>(limit: NonZeroU64) -> Self {
        Self {
            total: Arc::new(AtomicU64::new(limit.get())),
            shared_by: Arc::new(AtomicU16::new(
                NonZeroU16::new(N).expect("`shared_by` cannot be 0").get(),
            )),
        }
    }

    /// Get the current limit for a single connection.
    #[must_use]
    pub fn current(&self) -> Option<NonZeroU64> {
        let total = self.total.load(Ordering::Relaxed);

        if total == 0 {
            None
        } else {
            let shared_by = self.shared_by.load(Ordering::Relaxed);

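            // Round the division up so each connection's share is never zero;
            // e.g. a total of 10 B/s shared by 3 connections yields 4 B/s each.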
            NonZeroU64::new((total + u64::from(shared_by) - 1) / u64::from(shared_by))
        }
    }

    /// Disable the rate limit (by setting the total to 0).
    pub fn set_disable(&self) {
        self.total.store(0, Ordering::Release);
    }

    /// Set the total rate limit.
    pub fn set_total(&self, limit: NonZeroU64) {
        self.total.store(limit.get(), Ordering::Release);
    }

    /// Update the number of shared instances.
    pub fn set_share_by(&self, shared_by: NonZeroU16) {
        self.shared_by.store(shared_by.get(), Ordering::Release);
    }

    /// Increase the number of shared instances by `N`.
    ///
    /// # Panics
    ///
    /// If `N` is 0.
    pub fn inc_shared_by_n<const N: u16>(&self) {
        self.inc_shared_by({
            NonZeroU16::new(N).expect("`inc_shared_by_n` cannot be called with 0")
        });
    }

    /// Increase the number of shared instances.
    pub fn inc_shared_by(&self, inc: NonZeroU16) {
        let _ = self
            .shared_by
            .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
                Some(current.saturating_add(inc.get()))
            });
    }

    /// Decrease the number of shared instances by `N`.
    ///
    /// # Panics
    ///
    /// If `N` is 0.
    ///
    /// # Errors
    ///
    /// See [`dec_shared_by`](Self::dec_shared_by).
    pub fn dec_shared_by_n<const N: u16>(&self) -> io::Result<()> {
        self.dec_shared_by({
            NonZeroU16::new(N).expect("`dec_shared_by_n` cannot be called with 0")
        })
    }

    /// Decrease the number of shared instances.
    ///
    /// # Errors
    ///
    /// Errors if the decrement would bring `shared_by` to 0 (or below).
    pub fn dec_shared_by(&self, dec: NonZeroU16) -> io::Result<()> {
        #[allow(clippy::redundant_closure_for_method_calls)]
        self.shared_by
            .fetch_update(Ordering::Release, Ordering::Acquire, |shared_by| {
                shared_by
                    .checked_sub(dec.get())
                    .and_then(NonZeroU16::new)
                    .map(|s| s.get())
            })
            .map(|_| ())
            .map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "cannot decrease `shared_by` to 0",
                )
            })
    }

    /// Clone the `RateLimit` with the same total limit, and increase the shared
    /// instances count by 1.
    #[must_use]
    pub fn clone_shared(&self) -> Self {
        let new_limit = self.clone();
        new_limit.inc_shared_by_n::<1>();
        new_limit
    }
}

pub(crate) struct RateLimiter<const ENABLED: bool> {
    /// Byte transfer rate limit, in `B/s`.
    limit: CachePadded<RateLimit>,

    /// Available tokens (in bytes).
    tokens: Option<f64>,

    /// The last time the tokens were updated.
    last_updated: Option<Instant>,
}

impl<const ENABLED: bool> fmt::Debug for RateLimiter<ENABLED> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if ENABLED {
            f.debug_struct("RateLimiter")
                .field("enabled", &ENABLED)
                .field("limit", &self.limit)
                .field("tokens", &self.tokens)
                .field(
                    "since_last_updated",
                    &self.last_updated.map(|i| i.elapsed()),
                )
                .finish()
        } else {
            f.debug_struct("RateLimiter")
                .field("enabled", &ENABLED)
                .finish()
        }
    }
}

/// Rate limit result.
pub(crate) enum RateLimitResult {
    /// The requested number of bytes is accepted.
    Accepted,

    /// The request is throttled; the caller should stop reading for the
    /// specified duration.
    Throttled { now: Instant, dur: Duration },
}

impl fmt::Debug for RateLimitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RateLimitResult::Accepted => f.write_str("Accepted"),
            RateLimitResult::Throttled { dur, .. } => {
                f.debug_tuple("Throttled").field(dur).finish()
            }
        }
    }
}

impl RateLimiter<RATE_LIMITER_DISABLED> {
    /// Create a new rate limiter that is disabled.
    pub(crate) fn empty() -> Self {
        Self::new(RateLimit::new_disabled())
    }
}

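// Illustrative only (not in the original source): the `ENABLED` const generic
// selects the code path at compile time, so a disabled limiter short-circuits
// `check`, e.g.
//
//   let off = RateLimiter::<{ RATE_LIMITER_DISABLED }>::empty();
//   let on = RateLimiter::<{ RATE_LIMITER_ENABLED }>::new(RateLimit::new(limit));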
impl<const ENABLED: bool> RateLimiter<ENABLED> {
    /// Create a new rate limiter with the specified [`RateLimit`].
    pub(crate) const fn new(limit: RateLimit) -> Self {
        Self {
            limit: CachePadded::new(limit),
            tokens: None,
            last_updated: None,
        }
    }

    #[inline]
    #[cfg_attr(
        any(
            feature = "feat-tracing-trace",
            all(debug_assertions, feature = "feat-tracing")
        ),
        tracing::instrument(level = "TRACE", ret)
    )]
    /// Returns the ideal target length for the next splice, based on the
    /// current rate limit.
    pub(crate) fn ideal_len(&self, pipe_size: NonZeroUsize) -> Option<NonZeroUsize> {
        let Some(limit) = self.limit.current() else {
            // Rate limiter is disabled; `None` lets the caller fall back to
            // the pipe size.
            return None;
        };

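        // Aim for roughly two timer ticks' worth of bytes per splice, capped
        // by the pipe size; e.g. at 1 MiB/s one tick is ~1049 B, so the ideal
        // length is 2098 B (assuming the pipe is larger than that).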
        let ideal_len = min(
            (limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64()).ceil() as usize * 2,
            pipe_size.get(),
        );

        #[allow(unsafe_code)]
        // SAFETY: `limit >= 1` and `pipe_size >= 1`, so `ideal_len` is
        // guaranteed to be non-zero.
        Some(unsafe { NonZeroUsize::new_unchecked(ideal_len) })
    }

    #[cfg_attr(
        any(
            feature = "feat-tracing-trace",
            all(debug_assertions, feature = "feat-tracing")
        ),
        tracing::instrument(level = "TRACE", ret)
    )]
    /// Check whether the rate limiter accepts `has_read` bytes, or whether the
    /// caller should sleep before reading again.
    pub(crate) fn check(&mut self, has_read: NonZeroUsize) -> RateLimitResult {
        if !ENABLED {
            // Rate limiting is disabled at compile time; always accept.
            return RateLimitResult::Accepted;
        }

        let Some(limit) = self.limit.current() else {
            // The limit is currently set to 0 (disabled); reset state and accept.
            self.tokens = None;
            self.last_updated = None;

            return RateLimitResult::Accepted;
        };

        let now = Instant::now();

        let Some(ref mut last_updated) = self.last_updated else {
            // First call with the limiter enabled: just record the time.
            self.last_updated = Some(now);

            return RateLimitResult::Accepted;
        };

        let current_tokens = if let Some(ref mut tokens) = self.tokens {
            tokens
        } else {
            // Initialize the bucket with one timer tick's worth of tokens.
            self.tokens = Some(limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64());
            self.tokens.as_mut().unwrap()
        };

        // Refill the token bucket.
        Self::refill(current_tokens, now, last_updated, limit);

        // Try to acquire the tokens.
        {
            *current_tokens -= has_read.get() as f64;

            if current_tokens.is_sign_negative() {
                let insufficient_tokens = current_tokens.abs();

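                // Sleep roughly long enough for the refill to cover the
                // deficit; e.g. a 4096 B deficit at 1 MiB/s needs ~3.9 ms of
                // refill, which the whole-second floor below turns into 0 s,
                // so the 1 ms clamp applies and the caller re-checks on the
                // next tick.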
                return RateLimitResult::Throttled {
                    now,
                    dur: Duration::from_secs_f64(
                        (insufficient_tokens / limit.get() as f64).floor(),
                    )
                    .max(TOKIO_TIMER_MIN_DUR), // clamp to Tokio's 1 ms timer resolution
                };
            }
        }

        RateLimitResult::Accepted
    }

    #[inline]
    /// Refill the token bucket with the elapsed time since the last update.
    fn refill(tokens: &mut f64, now: Instant, last_updated: &mut Instant, limit: NonZeroU64) {
        let Some(elapsed) = now.checked_duration_since(*last_updated) else {
            // The last update is in the future; force-update it.
            *last_updated = now;

            return;
        };

        *last_updated = now;

        let new_tokens = *tokens + (limit.get() as f64 * elapsed.as_secs_f64());
        let max_new_tokens = limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64();
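        // e.g. with a 1 MiB/s limit, `max_new_tokens` is ~1048.6 B: the bucket
        // never holds more than one timer tick's worth of budget.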

        // Cap the tokens at one timer tick's worth (`limit * TOKIO_TIMER_MIN_DUR`).
        if new_tokens.is_normal() {
            // Without this cap, a long idle period with few tokens acquired
            // would accumulate a huge burst budget, which is not what we want.

            *tokens = if max_new_tokens <= new_tokens {
                max_new_tokens
            } else {
                new_tokens
            };
        } else {
            // Rare: zero, subnormal, or non-finite; reset to the cap.
            *tokens = max_new_tokens;
        }
    }
}
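
// A minimal test sketch (not part of the original file); it only exercises the
// share-accounting API defined above.
#[cfg(test)]
mod rate_limit_tests {
    use std::num::{NonZeroU16, NonZeroU64};

    use super::RateLimit;

    #[test]
    fn shares_are_tracked_and_cannot_drop_to_zero() {
        let limit = RateLimit::new(NonZeroU64::new(1_000_000).unwrap());
        assert!(limit.current().is_some());

        // `clone_shared` shares the same counters and bumps the share count,
        // so the per-connection limit drops below the total.
        let peer = limit.clone_shared();
        assert!(peer.current().unwrap().get() < 1_000_000);

        // The share count can be decreased back to 1, but never to 0.
        assert!(peer.dec_shared_by(NonZeroU16::new(1).unwrap()).is_ok());
        assert!(limit.dec_shared_by(NonZeroU16::new(1).unwrap()).is_err());

        // Disabling the limit makes `current` report `None`.
        limit.set_disable();
        assert!(limit.current().is_none());
    }
}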