stdout_channel/
rate_limiter.rs1use std::sync::{
2 atomic::{AtomicUsize, Ordering},
3 Arc,
4};
5use tokio::{
6 sync::Notify,
7 task::{spawn, JoinHandle},
8 time::{sleep, Duration},
9};
10
11#[derive(Clone)]
12pub struct RateLimiter {
13 inner: Arc<RateLimiterInner>,
14 #[allow(dead_code)]
15 rate_task: Arc<JoinHandle<()>>,
16}
17
18impl RateLimiter {
19 #[must_use]
20 pub fn new(max_per_unit_time: usize, unit_time_ms: usize) -> Self {
21 let inner = Arc::new(RateLimiterInner::new(max_per_unit_time, unit_time_ms));
22 let rate_task = Arc::new({
23 let inner = inner.clone();
24 spawn(async move {
25 inner.check_reset().await;
26 })
27 });
28 Self { inner, rate_task }
29 }
30
31 pub async fn acquire(&self) {
32 self.inner.acquire().await;
33 }
34}
35
36struct RateLimiterInner {
37 max_per_unit_time: usize,
38 unit_time_ms: usize,
39 remaining: AtomicUsize,
40 notify: Notify,
41}
42
43impl RateLimiterInner {
44 fn new(max_per_unit_time: usize, unit_time_ms: usize) -> Self {
45 Self {
46 max_per_unit_time,
47 unit_time_ms,
48 remaining: AtomicUsize::new(max_per_unit_time),
49 notify: Notify::new(),
50 }
51 }
52
53 fn decrement_remaining(&self) -> bool {
54 fn gtzero(x: usize) -> Option<usize> {
55 if x > 0 {
56 Some(x - 1)
57 } else {
58 None
59 }
60 }
61
62 self.remaining
63 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, gtzero)
64 .is_ok()
65 }
66
67 async fn acquire(&self) {
68 loop {
69 if self.decrement_remaining() {
70 return;
71 }
72 self.notify.notified().await;
73 }
74 }
75
76 async fn check_reset(&self) {
77 loop {
78 self.remaining
79 .fetch_max(self.max_per_unit_time, Ordering::SeqCst);
80 self.notify.notify_waiters();
81 sleep(Duration::from_millis(self.unit_time_ms as u64)).await;
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use log::debug;
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Instant,
    };
    use tokio::{
        task::spawn,
        time::{sleep, Duration},
    };

    use crate::{rate_limiter::RateLimiter, StdoutChannelError};

    /// Checks the elapsed-time lower bound for 10_000 acquisitions at
    /// 1_000 per 100 ms.
    ///
    /// These acquisitions need ten windows. The first window's budget is
    /// available immediately, so the last batch cannot finish before the
    /// ninth reset, roughly 900 ms after start.
    #[tokio::test]
    async fn test_rate_limiter() -> Result<(), StdoutChannelError> {
        // `try_init` rather than `init`: another test in the same binary may
        // already have installed a global logger, and `init` panics then.
        let _ = env_logger::try_init();

        // Monotonic clock: wall-clock time can jump and make the bound flaky.
        let start = Instant::now();

        let rate_limiter = RateLimiter::new(1000, 100);
        let test_count = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..10_000)
            .map(|_| {
                let rate_limiter = rate_limiter.clone();
                let test_count = Arc::clone(&test_count);
                spawn(async move {
                    rate_limiter.acquire().await;
                    test_count.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        sleep(Duration::from_millis(100)).await;

        // Progress trace, one sample per window.
        for _ in 0..5 {
            debug!("{}", test_count.load(Ordering::SeqCst));
            sleep(Duration::from_millis(100)).await;
        }
        for t in tasks {
            t.await?;
        }

        let elapsed_ms = start.elapsed().as_millis();
        let completed = test_count.load(Ordering::SeqCst);

        println!("{} {}", elapsed_ms, completed);
        assert!(elapsed_ms >= 900);
        assert_eq!(completed, 10_000);
        Ok(())
    }
}