1use std::time::Duration;
2use std::sync::Arc;
3use futures::Stream;
4use futures::stream::StreamExt;
5use tokio::sync::{Mutex, mpsc};
6use tokio::task::JoinHandle;
7use tokio::time::sleep;
8
9pub struct ThrottleStream<T> {
10 duration: Duration,
11 buffer: Arc<Mutex<Vec<T>>>,
12 task: Arc<Mutex<Option<JoinHandle<()>>>>,
13 sender: mpsc::Sender<Vec<T>>,
14}
15
16impl<T: Send + 'static> ThrottleStream<T> {
17 pub fn clone(&self) -> Self {
18 Self {
19 duration: self.duration,
20 buffer: self.buffer.clone(),
21 task: self.task.clone(),
22 sender: self.sender.clone(),
23 }
24 }
25
26 pub fn new(duration: Duration) -> (Self, mpsc::Receiver<Vec<T>>) {
27 let (sender, receiver) = mpsc::channel(1);
28 let buffer = Arc::new(Mutex::new(Vec::new()));
29 let task = Arc::new(Mutex::new(None));
30 (
31 Self {
32 duration,
33 buffer,
34 task,
35 sender,
36 },
37 receiver,
38 )
39 }
40
41 async fn send_data(&self, data: Vec<T>, task_guard: &mut Option<JoinHandle<()>>) {
42 if let Err(e) = self.sender.send(data).await {
43 eprintln!("Failed to send data: {}", e);
44 return;
45 }
46 *task_guard = Some(self.schedule_task());
47 }
48
49 fn schedule_task(&self) -> JoinHandle<()> {
50 let this = self.clone();
51 tokio::spawn(async move {
52 sleep(this.duration).await;
53
54 let mut buffer_guard = this.buffer.lock().await;
55 let mut task_guard = this.task.lock().await;
56 if buffer_guard.is_empty() {
57 *task_guard = None;
58 return;
59 }
60
61 let data = buffer_guard.drain(..).collect::<Vec<_>>();
62 this.send_data(data, &mut task_guard).await;
63 })
64 }
65
66 pub async fn push(&self, item: T) {
67 let mut task_guard = self.task.lock().await;
68 if task_guard.is_some() {
69 let mut buffer_guard = self.buffer.lock().await;
70 buffer_guard.push(item);
71 } else {
72 self.send_data(vec![item], &mut task_guard).await;
73 }
74 }
75
76 pub async fn drain(&self) {
77 let mut buffer_guard = self.buffer.lock().await;
78 let mut task_guard = self.task.lock().await;
79 if buffer_guard.is_empty() {
80 return;
81 }
82
83 let data = buffer_guard.drain(..).collect::<Vec<_>>();
84 self.send_data(data, &mut task_guard).await;
85 }
86}
87
88pub fn throttle_stream<S, T>(
89 input_stream: S,
90 duration: Duration,
91) -> impl Stream<Item = Vec<T>>
92where
93 S: Stream<Item = T> + Send + 'static,
94 T: Send + 'static,
95{
96 let (throttle, mut receiver) = ThrottleStream::new(duration);
97 tokio::spawn(async move {
98 let mut input_stream = Box::pin(input_stream);
99 while let Some(item) = input_stream.next().await {
100 throttle.push(item).await;
101 }
102 });
103 async_stream::stream! {
104 while let Some(data) = receiver.recv().await {
105 yield data;
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Throttle window shared by both tests.
    const WINDOW: Duration = Duration::from_millis(100);
    /// Slightly longer than `WINDOW`: any timer started before this sleep begins
    /// has fired by the time it ends.
    const SETTLE: Duration = Duration::from_millis(120);

    #[tokio::test]
    async fn test_throttle_stream() {
        let (throttle, mut rx) = ThrottleStream::new(WINDOW);

        // An idle throttle forwards the first item at once and opens a window.
        throttle.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        // While the window is open, items are held back...
        for item in [2, 3] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // ...until the timer fires and releases them as one batch.
        sleep(SETTLE).await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // Emitting that batch opened a fresh window, so these are held back too.
        for item in [4, 5] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        sleep(SETTLE).await;
        assert_eq!(rx.recv().await, Some(vec![4, 5]));
    }

    #[tokio::test]
    async fn test_throttle_stream_flush() {
        let (throttle, mut rx) = ThrottleStream::new(WINDOW);

        throttle.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        for item in [2, 3] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // `drain` sends the buffered items without waiting for the timer.
        throttle.drain().await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // Draining an empty buffer sends nothing.
        throttle.drain().await;
        assert!(rx.try_recv().is_err());

        // A window is still open after the drain, so new items are buffered again.
        for item in [4, 5] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        throttle.drain().await;
        assert_eq!(rx.try_recv(), Ok(vec![4, 5]));
    }
}