1#![allow(clippy::cast_precision_loss)]
4#![allow(clippy::cast_sign_loss)]
5#![allow(clippy::cast_possible_truncation)]
6
7use std::cmp::min;
8use std::num::{NonZeroU16, NonZeroU64, NonZeroUsize};
9use std::sync::atomic::{AtomicU16, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use std::{fmt, io};
13
14use crossbeam_utils::CachePadded;
15use tokio::time::Instant;
16
/// Shortest duration handed to the tokio timer (tokio timers work at millisecond
/// granularity). It is also the window that sizes the limiter's token bucket.
const TOKIO_TIMER_MIN_DUR: Duration = Duration::from_micros(1_000);

/// Const-generic flag selecting a `RateLimiter` that actually enforces its limit.
pub(crate) const RATE_LIMITER_ENABLED: bool = !RATE_LIMITER_DISABLED;
/// Const-generic flag selecting a pass-through `RateLimiter`.
pub(crate) const RATE_LIMITER_DISABLED: bool = false;
21
/// A per-second allowance that can be split between several rate limiters.
///
/// Cloning is cheap and yields a handle onto the *same* counters, so a change made
/// through one handle is observed by every clone.
#[derive(Clone, Debug)]
pub struct RateLimit {
    /// Total allowance across all sharers; `0` marks the limit as disabled.
    total: Arc<AtomicU64>,
    /// How many parties the total is divided between; kept non-zero by the setters.
    shared_by: Arc<AtomicU16>,
}
35
36impl RateLimit {
37 const DISABLED: u64 = 0;
38
39 #[must_use]
41 pub fn new(limit: NonZeroU64) -> Self {
42 Self {
43 total: Arc::new(AtomicU64::new(limit.get())),
44 shared_by: Arc::new(AtomicU16::new(1)),
45 }
46 }
47
48 #[must_use]
50 pub fn new_disabled() -> Self {
51 Self {
52 total: Arc::new(AtomicU64::new(Self::DISABLED)),
53 shared_by: Arc::new(AtomicU16::new(1)),
54 }
55 }
56
57 #[must_use]
63 pub fn new_shared_by<const N: u16>(limit: NonZeroU64) -> Self {
64 Self {
65 total: Arc::new(AtomicU64::new(limit.get())),
66 shared_by: Arc::new(AtomicU16::new(
67 NonZeroU16::new(N).expect("`shared_by cannot be 0`").get(),
68 )),
69 }
70 }
71
72 #[must_use]
74 pub fn current(&self) -> Option<NonZeroU64> {
75 let total = self.total.load(Ordering::Relaxed);
76
77 if total == 0 {
78 None
79 } else {
80 let shared_by = self.shared_by.load(Ordering::Relaxed);
81
82 NonZeroU64::new((total + u64::from(shared_by)) / u64::from(shared_by))
83 }
84 }
85
86 pub fn set_disable(&self) {
88 self.total.store(0, Ordering::Release);
89 }
90
91 pub fn set_total(&self, limit: NonZeroU64) {
93 self.total.store(limit.get(), Ordering::Release);
94 }
95
96 pub fn set_share_by(&self, shared_by: NonZeroU16) {
98 self.shared_by.store(shared_by.get(), Ordering::Release);
99 }
100
101 pub fn inc_shared_by_n<const N: u16>(&self) {
107 self.inc_shared_by({
108 NonZeroU16::new(N).expect("`inc_shared_by_n` cannot be called with 0")
109 });
110 }
111
112 pub fn inc_shared_by(&self, inc: NonZeroU16) {
114 let _ = self
115 .shared_by
116 .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
117 Some(current.saturating_add(inc.get()))
118 });
119 }
120
121 pub fn dec_shared_by_n<const N: u16>(&self) -> io::Result<()> {
131 self.dec_shared_by({
132 NonZeroU16::new(N).expect("`dec_shared_by_n` cannot be called with 0")
133 })
134 }
135
136 pub fn dec_shared_by(&self, dec: NonZeroU16) -> io::Result<()> {
142 #[allow(clippy::redundant_closure_for_method_calls)]
143 self.shared_by
144 .fetch_update(Ordering::Release, Ordering::Acquire, |shared_by| {
145 shared_by
146 .checked_sub(dec.get())
147 .and_then(NonZeroU16::new)
148 .map(|s| s.get())
149 })
150 .map(|_| ())
151 .map_err(|_| {
152 io::Error::new(
153 io::ErrorKind::InvalidInput,
154 "cannot decrease `shared_by` to 0",
155 )
156 })
157 }
158
159 #[must_use]
162 pub fn clone_shared(&self) -> Self {
163 let new_limit = self.clone();
164 new_limit.inc_shared_by_n::<1>();
165 new_limit
166 }
167}
168
169pub(crate) struct RateLimiter<const ENABLED: bool> {
170 limit: CachePadded<RateLimit>,
172
173 tokens: Option<f64>,
175
176 last_updated: Option<Instant>,
178}
179
180impl<const ENABLED: bool> fmt::Debug for RateLimiter<ENABLED> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 if ENABLED {
183 f.debug_struct("RateLimiter")
184 .field("enabled", &ENABLED)
185 .field("limit", &self.limit)
186 .field("tokens", &self.tokens)
187 .field(
188 "since_last_updated",
189 &self.last_updated.map(|i| i.elapsed()),
190 )
191 .finish()
192 } else {
193 f.debug_struct("RateLimiter")
194 .field("enabled", &ENABLED)
195 .finish()
196 }
197 }
198}
199
/// Outcome of charging a read against a `RateLimiter`.
pub(crate) enum RateLimitResult {
    /// The read fits in the current allowance.
    Accepted,

    /// The allowance is overdrawn. Presumably the caller waits for `dur` starting
    /// from `now` before continuing (confirm against callers).
    Throttled {
        /// Instant at which the check was made.
        now: Instant,
        /// Suggested wait.
        dur: Duration,
    },
}
209
210impl fmt::Debug for RateLimitResult {
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 match self {
213 RateLimitResult::Accepted => f.write_str("Accepted"),
214 RateLimitResult::Throttled { dur, .. } => {
215 f.debug_tuple("Throttled").field(dur).finish()
216 }
217 }
218 }
219}
220
221impl RateLimiter<RATE_LIMITER_DISABLED> {
222 pub(crate) fn empty() -> Self {
224 Self::new(RateLimit::new_disabled())
225 }
226}
227
228impl<const ENABLED: bool> RateLimiter<ENABLED> {
229 pub(crate) const fn new(limit: RateLimit) -> Self {
231 Self {
232 limit: CachePadded::new(limit),
233 tokens: None,
234 last_updated: None,
235 }
236 }
237
238 #[inline]
239 #[cfg_attr(
240 any(
241 feature = "feat-tracing-trace",
242 all(debug_assertions, feature = "feat-tracing")
243 ),
244 tracing::instrument(level = "TRACE", ret)
245 )]
246 pub(crate) fn ideal_len(&self, pipe_size: NonZeroUsize) -> Option<NonZeroUsize> {
249 let Some(limit) = self.limit.current() else {
250 return None;
252 };
253
254 let ideal_len = min(
255 (limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64()).ceil() as usize * 2,
256 pipe_size.get(),
257 );
258
259 #[allow(unsafe_code)]
260 Some(unsafe { NonZeroUsize::new_unchecked(ideal_len) })
262 }
263
264 #[cfg_attr(
265 any(
266 feature = "feat-tracing-trace",
267 all(debug_assertions, feature = "feat-tracing")
268 ),
269 tracing::instrument(level = "TRACE", ret)
270 )]
271 pub(crate) fn check(&mut self, has_read: NonZeroUsize) -> RateLimitResult {
273 if !ENABLED {
274 return RateLimitResult::Accepted;
276 }
277
278 let Some(limit) = self.limit.current() else {
279 self.tokens = None;
281 self.last_updated = None;
282
283 return RateLimitResult::Accepted;
284 };
285
286 let now = Instant::now();
287
288 let Some(ref mut last_updated) = self.last_updated else {
289 self.last_updated = Some(now);
291
292 return RateLimitResult::Accepted;
293 };
294
295 let current_tokens = if let Some(ref mut tokens) = self.tokens {
296 tokens
297 } else {
298 self.tokens = Some(limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64());
300 self.tokens.as_mut().unwrap()
301 };
302
303 Self::refill(current_tokens, now, last_updated, limit);
305
306 {
308 *current_tokens -= has_read.get() as f64;
309
310 if current_tokens.is_sign_negative() {
311 let insufficient_tokens = current_tokens.abs();
312
313 return RateLimitResult::Throttled {
314 now,
315 dur: Duration::from_secs_f64(
316 (insufficient_tokens / limit.get() as f64).floor(),
317 )
318 .max(TOKIO_TIMER_MIN_DUR), };
320 }
321 }
322
323 RateLimitResult::Accepted
324 }
325
326 #[inline]
327 fn refill(tokens: &mut f64, now: Instant, last_updated: &mut Instant, limit: NonZeroU64) {
329 let Some(elapsed) = now.checked_duration_since(*last_updated) else {
330 *last_updated = now;
332
333 return;
334 };
335
336 *last_updated = now;
337
338 let new_tokens = *tokens + (limit.get() as f64 * elapsed.as_secs_f64());
339 let max_new_tokens = limit.get() as f64 * TOKIO_TIMER_MIN_DUR.as_secs_f64();
340
341 if new_tokens.is_normal() {
343 *tokens = if max_new_tokens <= new_tokens {
347 max_new_tokens
348 } else {
349 new_tokens
350 };
351 } else {
352 *tokens = max_new_tokens;
354 }
355 }
356}