use std::time::Duration;
use std::sync::Arc;
use futures::Stream;
use futures::stream::StreamExt;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::sleep;
/// Groups pushed items into batches sent over an mpsc channel.
///
/// An item pushed while no cool-down timer is pending is sent at once. Items
/// pushed while a timer is pending are buffered and flushed as one batch when
/// the timer fires (or when `drain` is called). The buffer, timer slot and
/// channel sender are shared handles, so clones operate on the same state.
pub struct ThrottleStream<T> {
    /// Sending half of the batch channel whose receiver is returned by `new`.
    sender: mpsc::Sender<Vec<T>>,
    /// Items collected while a cool-down timer is pending.
    buffer: Arc<Mutex<Vec<T>>>,
    /// Handle of the pending cool-down timer; `None` means the next push is sent immediately.
    task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// How long the timer sleeps before flushing the buffer.
    duration: Duration,
}
/// Cheap handle clone: the buffer, timer slot and channel are shared, so all
/// clones feed the same batches. No bounds on `T` are required.
impl<T> Clone for ThrottleStream<T> {
    fn clone(&self) -> Self {
        Self {
            duration: self.duration,
            buffer: Arc::clone(&self.buffer),
            task: Arc::clone(&self.task),
            sender: self.sender.clone(),
        }
    }
}

// Locking discipline: every method takes `task` first and `buffer` second.
// Acquiring them in different orders in different places can deadlock.
impl<T: Send + 'static> ThrottleStream<T> {
    /// Creates a throttle whose timer waits `duration` between flushes.
    ///
    /// Returns the throttle handle and the receiving end of the batch channel.
    /// The channel holds at most one undelivered batch, so sending waits for the
    /// consumer once a batch is already pending.
    pub fn new(duration: Duration) -> (Self, mpsc::Receiver<Vec<T>>) {
        let (sender, receiver) = mpsc::channel(1);
        let throttle = Self {
            duration,
            buffer: Arc::new(Mutex::new(Vec::new())),
            task: Arc::new(Mutex::new(None)),
            sender,
        };
        (throttle, receiver)
    }

    /// Sends `data` and, on success, stores a freshly started cool-down timer
    /// in `task_guard` (the locked contents of `self.task`).
    ///
    /// On failure (receiver dropped) the error is logged and the slot is
    /// cleared. Otherwise the throttle would stay "cooling down" forever and
    /// keep buffering items that can never be delivered.
    async fn send_data(&self, data: Vec<T>, task_guard: &mut Option<JoinHandle<()>>) {
        if let Err(e) = self.sender.send(data).await {
            eprintln!("Failed to send data: {}", e);
            *task_guard = None;
            return;
        }
        *task_guard = Some(self.schedule_task());
    }

    /// Spawns the cool-down timer. After `duration` it flushes the buffer as one
    /// batch (starting another timer). If nothing was buffered, it clears the
    /// timer slot so the next push is sent immediately.
    fn schedule_task(&self) -> JoinHandle<()> {
        let this = self.clone();
        tokio::spawn(async move {
            sleep(this.duration).await;
            // `task` before `buffer`, matching `push` and `drain`.
            let mut task_guard = this.task.lock().await;
            let data = {
                let mut buffer_guard = this.buffer.lock().await;
                if buffer_guard.is_empty() {
                    *task_guard = None;
                    return;
                }
                std::mem::take(&mut *buffer_guard)
            };
            this.send_data(data, &mut task_guard).await;
        })
    }

    /// Adds `item`. If no cool-down timer is pending, the item is sent right
    /// away as a one-element batch. Otherwise it is buffered for the next flush.
    pub async fn push(&self, item: T) {
        let mut task_guard = self.task.lock().await;
        if task_guard.is_some() {
            self.buffer.lock().await.push(item);
        } else {
            self.send_data(vec![item], &mut task_guard).await;
        }
    }

    /// Immediately sends everything currently buffered as one batch and
    /// restarts the cool-down timer. Does nothing if the buffer is empty.
    pub async fn drain(&self) {
        let mut task_guard = self.task.lock().await;
        let data = {
            let mut buffer_guard = self.buffer.lock().await;
            if buffer_guard.is_empty() {
                return;
            }
            std::mem::take(&mut *buffer_guard)
        };
        // Cancel the timer that was pending before this flush. Left running, it
        // would fire alongside the new one and flush later items too early.
        if let Some(stale) = task_guard.take() {
            stale.abort();
        }
        self.send_data(data, &mut task_guard).await;
    }
}
/// Turns `input_stream` into a stream of batches throttled by `duration`.
///
/// An item arriving while the throttle is idle is yielded immediately as a
/// one-element batch. Items arriving during the cool-down are grouped and
/// yielded together when the timer fires.
///
/// Must be called from within a Tokio runtime: items are forwarded by a
/// background task started with `tokio::spawn`. That task stops pulling from
/// `input_stream` once the returned stream has been dropped.
pub fn throttle_stream<S, T>(
    input_stream: S,
    duration: Duration,
) -> impl Stream<Item = Vec<T>>
where
    S: Stream<Item = T> + Send + 'static,
    T: Send + 'static,
{
    let (throttle, mut receiver) = ThrottleStream::new(duration);
    tokio::spawn(async move {
        let mut input_stream = Box::pin(input_stream);
        while let Some(item) = input_stream.next().await {
            // The consumer is gone, so nothing pushed from here on can be
            // delivered. Stop instead of draining the input for nothing.
            if throttle.sender.is_closed() {
                break;
            }
            throttle.push(item).await;
        }
    });
    async_stream::stream! {
        while let Some(data) = receiver.recv().await {
            yield data;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Awaits the next batch, failing the test instead of hanging if none arrives.
    async fn next_batch(receiver: &mut mpsc::Receiver<Vec<i32>>) -> Option<Vec<i32>> {
        tokio::time::timeout(Duration::from_secs(1), receiver.recv())
            .await
            .expect("timed out waiting for a throttled batch")
    }

    /// An idle push is emitted at once; pushes during the cool-down are batched
    /// and emitted by the timer.
    #[tokio::test]
    async fn test_throttle_stream() {
        let (throttle, mut receiver) = ThrottleStream::new(Duration::from_millis(100));
        throttle.push(1).await;
        assert_eq!(receiver.try_recv(), Ok(vec![1]));
        throttle.push(2).await;
        throttle.push(3).await;
        assert!(receiver.try_recv().is_err());
        sleep(Duration::from_millis(120)).await;
        // Await rather than `try_recv`: if the runtime is late, this task can be
        // polled before the timer task has sent, making `try_recv` fail spuriously.
        assert_eq!(next_batch(&mut receiver).await, Some(vec![2, 3]));
        throttle.push(4).await;
        throttle.push(5).await;
        assert!(receiver.try_recv().is_err());
        sleep(Duration::from_millis(120)).await;
        assert_eq!(next_batch(&mut receiver).await, Some(vec![4, 5]));
    }

    /// `drain` flushes buffered items immediately and is a no-op when empty.
    #[tokio::test]
    async fn test_throttle_stream_flush() {
        let (throttle, mut receiver) = ThrottleStream::new(Duration::from_millis(100));
        throttle.push(1).await;
        assert_eq!(receiver.try_recv(), Ok(vec![1]));
        throttle.push(2).await;
        throttle.push(3).await;
        assert!(receiver.try_recv().is_err());
        throttle.drain().await;
        assert_eq!(receiver.try_recv(), Ok(vec![2, 3]));
        throttle.drain().await;
        assert!(receiver.try_recv().is_err());
        throttle.push(4).await;
        throttle.push(5).await;
        assert!(receiver.try_recv().is_err());
        throttle.drain().await;
        assert_eq!(receiver.try_recv(), Ok(vec![4, 5]));
    }
}