wesichain_core/
rate_limiter.rs1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use futures::stream::BoxStream;
5use futures::StreamExt as _;
6use tokio::sync::Mutex;
7
8use crate::{Runnable, StreamEvent, WesichainError};
9
10pub struct RateLimited<R> {
11 inner: R,
12 interval: Duration,
13 last_call: Arc<Mutex<Option<Instant>>>,
14}
15
16impl<R> RateLimited<R> {
17 pub fn new(inner: R, requests_per_minute: u32) -> Self {
18 let interval = if requests_per_minute == 0 {
19 Duration::from_secs(u64::MAX / 2)
20 } else {
21 Duration::from_secs(60) / requests_per_minute
22 };
23 Self {
24 inner,
25 interval,
26 last_call: Arc::new(Mutex::new(None)),
27 }
28 }
29
30 async fn throttle(&self) {
31 let mut last = self.last_call.lock().await;
32 if let Some(t) = *last {
33 let elapsed = t.elapsed();
34 if elapsed < self.interval {
35 tokio::time::sleep(self.interval - elapsed).await;
36 }
37 }
38 *last = Some(Instant::now());
39 }
40}
41
42#[async_trait::async_trait]
43impl<Input, Output, R> Runnable<Input, Output> for RateLimited<R>
44where
45 Input: Send + Clone + 'static,
46 Output: Send + 'static,
47 R: Runnable<Input, Output> + Send + Sync,
48{
49 async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
50 self.throttle().await;
51 self.inner.invoke(input).await
52 }
53
54 fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
55 let inner = &self.inner;
56 async_stream::stream! {
57 self.throttle().await;
58 let mut s = inner.stream(input);
59 while let Some(event) = s.next().await {
60 yield event;
61 }
62 }
63 .boxed()
64 }
65}