synaptic_models/
token_bucket.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapticError};
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8pub struct TokenBucket {
14 capacity: f64,
15 tokens: Mutex<f64>,
16 refill_rate: f64,
17 last_refill: Mutex<Instant>,
18}
19
20impl TokenBucket {
21 pub fn new(capacity: f64, refill_rate: f64) -> Self {
26 Self {
27 capacity,
28 tokens: Mutex::new(capacity),
29 refill_rate,
30 last_refill: Mutex::new(Instant::now()),
31 }
32 }
33
34 pub async fn acquire(&self) {
36 loop {
37 self.refill().await;
38
39 let mut tokens = self.tokens.lock().await;
40 if *tokens >= 1.0 {
41 *tokens -= 1.0;
42 return;
43 }
44 drop(tokens);
45
46 let wait = std::time::Duration::from_secs_f64(1.0 / self.refill_rate);
49 tokio::time::sleep(wait).await;
50 }
51 }
52
53 async fn refill(&self) {
54 let now = Instant::now();
55 let mut last_refill = self.last_refill.lock().await;
56 let elapsed = now.duration_since(*last_refill);
57 let new_tokens = elapsed.as_secs_f64() * self.refill_rate;
58
59 if new_tokens > 0.0 {
60 let mut tokens = self.tokens.lock().await;
61 *tokens = (*tokens + new_tokens).min(self.capacity);
62 *last_refill = now;
63 }
64 }
65}
66
67pub struct TokenBucketChatModel {
72 inner: Arc<dyn ChatModel>,
73 bucket: Arc<TokenBucket>,
74}
75
76impl TokenBucketChatModel {
77 pub fn new(inner: Arc<dyn ChatModel>, capacity: f64, refill_rate: f64) -> Self {
83 Self {
84 inner,
85 bucket: Arc::new(TokenBucket::new(capacity, refill_rate)),
86 }
87 }
88}
89
90#[async_trait]
91impl ChatModel for TokenBucketChatModel {
92 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
93 self.bucket.acquire().await;
94 self.inner.chat(request).await
95 }
96
97 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
98 let inner = self.inner.clone();
99 let bucket = self.bucket.clone();
100
101 Box::pin(async_stream::stream! {
102 bucket.acquire().await;
103
104 use futures::StreamExt;
105 let mut stream = inner.stream_chat(request);
106 while let Some(result) = stream.next().await {
107 yield result;
108 }
109 })
110 }
111}