throttle_stream/
lib.rs

1use std::time::Duration;
2use std::sync::Arc;
3use futures::Stream;
4use futures::stream::StreamExt;
5use tokio::sync::{Mutex, mpsc};
6use tokio::task::JoinHandle;
7use tokio::time::sleep;
8
9pub struct ThrottleStream<T> {
10    duration: Duration,
11    buffer: Arc<Mutex<Vec<T>>>,
12    task: Arc<Mutex<Option<JoinHandle<()>>>>,
13    sender: mpsc::Sender<Vec<T>>,
14}
15
16impl<T: Send + 'static> ThrottleStream<T> {
17    pub fn clone(&self) -> Self {
18        Self {
19            duration: self.duration,
20            buffer: self.buffer.clone(),
21            task: self.task.clone(),
22            sender: self.sender.clone(),
23        }
24    }
25
26    pub fn new(duration: Duration) -> (Self, mpsc::Receiver<Vec<T>>) {
27        let (sender, receiver) = mpsc::channel(1);
28        let buffer = Arc::new(Mutex::new(Vec::new()));
29        let task = Arc::new(Mutex::new(None));
30        (
31            Self {
32                duration,
33                buffer,
34                task,
35                sender,
36            },
37            receiver,
38        )
39    }
40
41    async fn send_data(&self, data: Vec<T>, task_guard: &mut Option<JoinHandle<()>>) {
42        if let Err(e) = self.sender.send(data).await {
43            eprintln!("Failed to send data: {}", e);
44            return;
45        }
46        *task_guard = Some(self.schedule_task());
47    }
48
49    fn schedule_task(&self) -> JoinHandle<()> {
50        let this = self.clone();
51        tokio::spawn(async move {
52            sleep(this.duration).await;
53
54            let mut buffer_guard = this.buffer.lock().await;
55            let mut task_guard = this.task.lock().await;
56            if buffer_guard.is_empty() {
57                *task_guard = None;
58                return;
59            }
60
61            let data = buffer_guard.drain(..).collect::<Vec<_>>();
62            this.send_data(data, &mut task_guard).await;
63        })
64    }
65
66    pub async fn push(&self, item: T) {
67        let mut task_guard = self.task.lock().await;
68        if task_guard.is_some() {
69            let mut buffer_guard = self.buffer.lock().await;
70            buffer_guard.push(item);
71        } else {
72            self.send_data(vec![item], &mut task_guard).await;
73        }
74    }
75
76    pub async fn drain(&self) {
77        let mut buffer_guard = self.buffer.lock().await;
78        let mut task_guard = self.task.lock().await;
79        if buffer_guard.is_empty() {
80            return;
81        }
82
83        let data = buffer_guard.drain(..).collect::<Vec<_>>();
84        self.send_data(data, &mut task_guard).await;
85    }
86}
87
88pub fn throttle_stream<S, T>(
89    input_stream: S,
90    duration: Duration,
91) -> impl Stream<Item = Vec<T>>
92where
93    S: Stream<Item = T> + Send + 'static,
94    T: Send + 'static,
95{
96    let (throttle, mut receiver) = ThrottleStream::new(duration);
97    tokio::spawn(async move {
98        let mut input_stream = Box::pin(input_stream);
99        while let Some(item) = input_stream.next().await {
100            throttle.push(item).await;
101        }
102    });
103    async_stream::stream! {
104        while let Some(data) = receiver.recv().await {
105            yield data;
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Throttle window used by every test in this module.
    const WINDOW: Duration = Duration::from_millis(100);
    /// A little longer than `WINDOW`, so a timer armed at the start of the sleep has fired.
    const PAST_WINDOW: Duration = Duration::from_millis(120);

    #[tokio::test]
    async fn test_throttle_stream() {
        let (throttle, mut rx) = ThrottleStream::new(WINDOW);

        // t = 0ms: no window is open, so the first item goes straight through.
        throttle.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        // t = 0ms: the window opened by item 1 is still running; these are held back.
        for item in [2, 3] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // t = 120ms: the timer fired at 100ms and flushed items 2 and 3 as one batch.
        sleep(PAST_WINDOW).await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // t = 120ms: that flush opened a new window, so these are held back too.
        for item in [4, 5] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // t = 240ms: the second window closed at 200ms.
        sleep(PAST_WINDOW).await;
        assert_eq!(rx.recv().await, Some(vec![4, 5]));
    }

    #[tokio::test]
    async fn test_throttle_stream_flush() {
        let (throttle, mut rx) = ThrottleStream::new(WINDOW);

        // Everything below happens at t = 0ms: `drain` flushes without waiting for a timer.
        throttle.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        for item in [2, 3] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        throttle.drain().await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // Draining an empty buffer sends nothing.
        throttle.drain().await;
        assert!(rx.try_recv().is_err());

        for item in [4, 5] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        throttle.drain().await;
        assert_eq!(rx.try_recv(), Ok(vec![4, 5]));
    }
}