use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::sleep;
use std::sync::Arc;
/// Handle to a throttled batching stream built on a tokio `mpsc` channel.
///
/// An item pushed while no timer task is registered is sent right away as a
/// one-element batch and a timer task is started. Items pushed while a timer
/// task is registered are collected in `buffer` and sent together as one
/// `Vec<T>` once the timer task has slept for `duration`.
pub struct ThrottleStream<T> {
/// How long each timer task sleeps before it looks at the buffer.
duration: Duration,
/// Items pushed while a timer task is registered; shared between handles.
buffer: Arc<Mutex<Vec<T>>>,
/// Handle of the registered timer task, or `None` when no timer is
/// registered. The handle is only tested for presence in this file.
task: Arc<Mutex<Option<JoinHandle<()>>>>,
/// Sending half of the channel on which batches are delivered.
sender: mpsc::Sender<Vec<T>>,
}
impl<T: Send + 'static> ThrottleStream<T> {
    /// Creates a throttle with the given window.
    ///
    /// Returns the throttle together with the receiving half of the channel
    /// on which batches (`Vec<T>`) are delivered. The channel is created with
    /// a capacity of one batch, so a send waits while an earlier batch is
    /// still unreceived.
    pub fn new(duration: Duration) -> (Self, mpsc::Receiver<Vec<T>>) {
        let (sender, receiver) = mpsc::channel(1);
        let stream = Self {
            duration,
            buffer: Arc::new(Mutex::new(Vec::new())),
            task: Arc::new(Mutex::new(None)),
            sender,
        };
        (stream, receiver)
    }

    /// Spawns the timer task for one throttle window and returns its handle.
    ///
    /// After sleeping for `duration` the task takes everything buffered so
    /// far. An empty buffer clears the timer slot, so the next `push` sends
    /// immediately. Otherwise the batch is sent and a new timer task is
    /// registered for the following window.
    fn schedule_task(&self) -> JoinHandle<()> {
        let this = self.clone();
        tokio::spawn(async move {
            sleep(this.duration).await;

            // Lock order is always `task` then `buffer`, the same as in
            // `push`. Taking them in the opposite order here would let a
            // concurrent `push` and this task each hold one lock while
            // waiting for the other.
            let mut task_guard = this.task.lock().await;

            // The buffer guard is a temporary, so the buffer lock is released
            // before the (possibly waiting) send below.
            let data = std::mem::take(&mut *this.buffer.lock().await);
            if data.is_empty() {
                *task_guard = None;
                return;
            }

            if let Err(e) = this.sender.send(data).await {
                eprintln!("Failed to send data: {}", e);
                // Clear the slot: leaving a finished handle registered would
                // make every later `push` append to a buffer nobody flushes.
                *task_guard = None;
                return;
            }

            *task_guard = Some(this.schedule_task());
        })
    }

    /// Pushes one item into the stream.
    ///
    /// While a timer task is registered the item is appended to the shared
    /// buffer. Otherwise it is sent immediately as a one-element batch and a
    /// timer task is started. A failed send is reported on stderr and the
    /// item is dropped; no timer is started in that case.
    ///
    /// The timer slot stays locked for the whole call, including the send,
    /// so concurrent pushes and the timer task are serialized.
    pub async fn push(&self, item: T) {
        let mut task_guard = self.task.lock().await;

        if task_guard.is_some() {
            self.buffer.lock().await.push(item);
            return;
        }

        if let Err(e) = self.sender.send(vec![item]).await {
            eprintln!("Failed to send data: {}", e);
            return;
        }
        *task_guard = Some(self.schedule_task());
    }
}

impl<T> Clone for ThrottleStream<T> {
    /// Returns another handle to the same stream.
    ///
    /// The buffer and the timer slot are shared through their `Arc`s and the
    /// channel sender is cloned, so items pushed through any handle end up in
    /// the same batches.
    fn clone(&self) -> Self {
        Self {
            duration: self.duration,
            buffer: Arc::clone(&self.buffer),
            task: Arc::clone(&self.task),
            sender: self.sender.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the throttle through an immediate send followed by two
    /// buffered windows and checks the batches seen on the receiver.
    #[tokio::test]
    async fn test_throttle_stream() {
        let window = Duration::from_millis(100);
        // Slightly longer than the window so the timer task has fired.
        let past_window = Duration::from_millis(120);
        let (stream, mut rx) = ThrottleStream::new(window);

        // No timer is registered yet: the first item goes out on its own.
        stream.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        // A timer is now registered, so these are held back...
        for item in [2, 3] {
            stream.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // ...and arrive together once the window has elapsed.
        sleep(past_window).await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // A non-empty flush registers a new timer, so the next items are
        // buffered as well.
        for item in [4, 5] {
            stream.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        sleep(past_window).await;
        assert_eq!(rx.recv().await, Some(vec![4, 5]));
    }
}