Skip to main content

wesichain_core/
rate_limiter.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use futures::stream::BoxStream;
5use futures::StreamExt as _;
6use tokio::sync::Mutex;
7
8use crate::{Runnable, StreamEvent, WesichainError};
9
10pub struct RateLimited<R> {
11    inner: R,
12    interval: Duration,
13    last_call: Arc<Mutex<Option<Instant>>>,
14}
15
16impl<R> RateLimited<R> {
17    pub fn new(inner: R, requests_per_minute: u32) -> Self {
18        let interval = if requests_per_minute == 0 {
19            Duration::from_secs(u64::MAX / 2)
20        } else {
21            Duration::from_secs(60) / requests_per_minute
22        };
23        Self {
24            inner,
25            interval,
26            last_call: Arc::new(Mutex::new(None)),
27        }
28    }
29
30    async fn throttle(&self) {
31        let mut last = self.last_call.lock().await;
32        if let Some(t) = *last {
33            let elapsed = t.elapsed();
34            if elapsed < self.interval {
35                tokio::time::sleep(self.interval - elapsed).await;
36            }
37        }
38        *last = Some(Instant::now());
39    }
40}
41
42#[async_trait::async_trait]
43impl<Input, Output, R> Runnable<Input, Output> for RateLimited<R>
44where
45    Input: Send + Clone + 'static,
46    Output: Send + 'static,
47    R: Runnable<Input, Output> + Send + Sync,
48{
49    async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
50        self.throttle().await;
51        self.inner.invoke(input).await
52    }
53
54    fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
55        let inner = &self.inner;
56        async_stream::stream! {
57            self.throttle().await;
58            let mut s = inner.stream(input);
59            while let Some(event) = s.next().await {
60                yield event;
61            }
62        }
63        .boxed()
64    }
65}