// throttle_stream/lib.rs
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::sleep;
use std::sync::Arc;

/// Groups values pushed within a time window and forwards them as `Vec<T>` batches over a
/// tokio `mpsc` channel.
///
/// A value pushed while no window is open is sent immediately as a one-element batch, and
/// that opens a window of `duration`. Values pushed while a window is open are buffered.
/// When the window expires, the buffer is sent as one batch and a new window is opened.
/// A window that expires with an empty buffer closes, so the next push is again sent
/// immediately.
///
/// Clones share the same buffer, window state and output channel.
pub struct ThrottleStream<T> {
    /// Length of each throttle window.
    duration: Duration,
    /// Values pushed while a window is open, waiting for the next flush.
    buffer: Arc<Mutex<Vec<T>>>,
    /// Handle of the task driving the current window; `None` when no window is open.
    ///
    /// Lock-ordering invariant: whenever both locks are needed, `task` is locked before
    /// `buffer`. Taking them in the opposite order anywhere can deadlock.
    task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Sending half of the channel whose receiver is returned by [`ThrottleStream::new`].
    sender: mpsc::Sender<Vec<T>>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would demand `T: Clone`, but only
// the `Arc`s, the `Duration` and the `Sender` are cloned here, never a `T`.
impl<T> Clone for ThrottleStream<T> {
    fn clone(&self) -> Self {
        Self {
            duration: self.duration,
            buffer: Arc::clone(&self.buffer),
            task: Arc::clone(&self.task),
            sender: self.sender.clone(),
        }
    }
}

impl<T: Send + 'static> ThrottleStream<T> {
    /// Creates a throttle with windows of `duration` and returns it together with the
    /// receiving end of its output channel.
    ///
    /// The channel has capacity 1. Each send therefore waits until the receiver has taken
    /// the previous batch.
    pub fn new(duration: Duration) -> (Self, mpsc::Receiver<Vec<T>>) {
        let (sender, receiver) = mpsc::channel(1);
        let throttle = Self {
            duration,
            buffer: Arc::new(Mutex::new(Vec::new())),
            task: Arc::new(Mutex::new(None)),
            sender,
        };
        (throttle, receiver)
    }

    /// Spawns the task that ends the current window after `duration`.
    ///
    /// When it fires:
    /// - If the buffer is empty, it closes the window (`task = None`).
    /// - Otherwise it sends the buffered values as one batch and schedules the next window.
    ///
    /// The caller is expected to store the returned handle in `task`.
    fn schedule_task(&self) -> JoinHandle<()> {
        let this = self.clone();
        tokio::spawn(async move {
            sleep(this.duration).await;

            // Lock `task` BEFORE `buffer`, the same order `push` uses. Acquiring them in the
            // reverse order deadlocks against a `push` that arrives as the window expires.
            let mut task_guard = this.task.lock().await;
            let data = {
                let mut buffer_guard = this.buffer.lock().await;
                if buffer_guard.is_empty() {
                    // Nothing arrived during this window: close it so the next push is
                    // sent immediately.
                    *task_guard = None;
                    return;
                }
                std::mem::take(&mut *buffer_guard)
            }; // buffer lock released here; no need to hold it across the send

            if let Err(e) = this.sender.send(data).await {
                eprintln!("Failed to send data: {}", e);
                // The receiver is gone. Close the window so later pushes fail fast in
                // `push` instead of piling up in `buffer` with no task left to flush it.
                *task_guard = None;
                return;
            }
            *task_guard = Some(this.schedule_task());
        })
    }

    /// Pushes `item` into the throttle.
    ///
    /// If no window is open, `item` is sent immediately as a one-element batch and a new
    /// window is opened. Otherwise `item` is buffered until the current window is flushed.
    /// If the receiver has been dropped, the error is printed to stderr and `item` is
    /// discarded.
    pub async fn push(&self, item: T) {
        // Lock order: `task` before `buffer` (see `schedule_task`).
        let mut task_guard = self.task.lock().await;
        if task_guard.is_some() {
            self.buffer.lock().await.push(item);
            return;
        }

        if let Err(e) = self.sender.send(vec![item]).await {
            eprintln!("Failed to send data: {}", e);
            return;
        }
        *task_guard = Some(self.schedule_task());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Throttle window under test.
    const WINDOW: Duration = Duration::from_millis(100);
    /// How long each step waits: the window plus 20ms of slack.
    const PAST_WINDOW: Duration = Duration::from_millis(120);

    #[tokio::test]
    async fn test_throttle_stream() {
        let (throttle, mut rx) = ThrottleStream::new(WINDOW);

        // t = 0ms: no window is open yet, so the first value goes out at once.
        throttle.push(1).await;
        assert_eq!(rx.try_recv(), Ok(vec![1]));

        // t = 0ms: the window opened by the first push holds these back.
        for item in [2, 3] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // t = 120ms: the window closed at 100ms and flushed the pending batch.
        sleep(PAST_WINDOW).await;
        assert_eq!(rx.try_recv(), Ok(vec![2, 3]));

        // t = 120ms: the follow-up window (100ms..200ms) is running, so these are buffered.
        for item in [4, 5] {
            throttle.push(item).await;
        }
        assert!(rx.try_recv().is_err());

        // t = 240ms: that follow-up window flushed them at 200ms.
        sleep(PAST_WINDOW).await;
        assert_eq!(rx.recv().await, Some(vec![4, 5]));
    }
}