stdout_channel/rate_limiter.rs

1use std::sync::{
2    atomic::{AtomicUsize, Ordering},
3    Arc,
4};
5use tokio::{
6    sync::Notify,
7    task::{spawn, JoinHandle},
8    time::{sleep, Duration},
9};
10
11#[derive(Clone)]
12pub struct RateLimiter {
13    inner: Arc<RateLimiterInner>,
14    #[allow(dead_code)]
15    rate_task: Arc<JoinHandle<()>>,
16}
17
18impl RateLimiter {
19    #[must_use]
20    pub fn new(max_per_unit_time: usize, unit_time_ms: usize) -> Self {
21        let inner = Arc::new(RateLimiterInner::new(max_per_unit_time, unit_time_ms));
22        let rate_task = Arc::new({
23            let inner = inner.clone();
24            spawn(async move {
25                inner.check_reset().await;
26            })
27        });
28        Self { inner, rate_task }
29    }
30
31    pub async fn acquire(&self) {
32        self.inner.acquire().await;
33    }
34}
35
36struct RateLimiterInner {
37    max_per_unit_time: usize,
38    unit_time_ms: usize,
39    remaining: AtomicUsize,
40    notify: Notify,
41}
42
43impl RateLimiterInner {
44    fn new(max_per_unit_time: usize, unit_time_ms: usize) -> Self {
45        Self {
46            max_per_unit_time,
47            unit_time_ms,
48            remaining: AtomicUsize::new(max_per_unit_time),
49            notify: Notify::new(),
50        }
51    }
52
53    fn decrement_remaining(&self) -> bool {
54        fn gtzero(x: usize) -> Option<usize> {
55            if x > 0 {
56                Some(x - 1)
57            } else {
58                None
59            }
60        }
61
62        self.remaining
63            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, gtzero)
64            .is_ok()
65    }
66
67    async fn acquire(&self) {
68        loop {
69            if self.decrement_remaining() {
70                return;
71            }
72            self.notify.notified().await;
73        }
74    }
75
76    async fn check_reset(&self) {
77        loop {
78            self.remaining
79                .fetch_max(self.max_per_unit_time, Ordering::SeqCst);
80            self.notify.notify_waiters();
81            sleep(Duration::from_millis(self.unit_time_ms as u64)).await;
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use log::debug;
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Instant,
    };
    use tokio::{
        task::spawn,
        time::{sleep, Duration},
    };

    use crate::{rate_limiter::RateLimiter, StdoutChannelError};

    /// Drives 10,000 acquirers through a 1000-permits-per-100ms limiter.
    ///
    /// Checks that every acquirer completes, and that doing so takes at least
    /// 900 ms: the first 1000 permits are available immediately, then nine more
    /// refills of 100 ms each are needed.
    #[tokio::test]
    async fn test_rate_limiter() -> Result<(), StdoutChannelError> {
        // `try_init` instead of `init`: if another test in the same binary has
        // already installed a logger, `init` would panic.
        let _ = env_logger::try_init();

        // Monotonic clock: wall-clock time can jump and would make the
        // elapsed-time assertion unreliable.
        let start = Instant::now();

        let rate_limiter = RateLimiter::new(1000, 100);
        let test_count = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..10_000)
            .map(|_| {
                let rate_limiter = rate_limiter.clone();
                let test_count = test_count.clone();
                spawn(async move {
                    rate_limiter.acquire().await;
                    test_count.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        sleep(Duration::from_millis(100)).await;

        for _ in 0..5 {
            let count = test_count.load(Ordering::SeqCst);
            debug!("{}", count);
            sleep(Duration::from_millis(100)).await;
        }
        for t in tasks {
            t.await?;
        }

        let elapsed_ms = start.elapsed().as_millis();

        println!("{} {}", elapsed_ms, test_count.load(Ordering::SeqCst));
        assert!(elapsed_ms >= 900);
        assert_eq!(test_count.load(Ordering::SeqCst), 10_000);
        Ok(())
    }
}