synaptic_models/
rate_limit.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapticError};
5use tokio::sync::Semaphore;
6
7pub struct RateLimitedChatModel {
8 inner: Arc<dyn ChatModel>,
9 semaphore: Arc<Semaphore>,
10}
11
12impl RateLimitedChatModel {
13 pub fn new(inner: Arc<dyn ChatModel>, max_concurrent: usize) -> Self {
14 Self {
15 inner,
16 semaphore: Arc::new(Semaphore::new(max_concurrent)),
17 }
18 }
19}
20
21#[async_trait]
22impl ChatModel for RateLimitedChatModel {
23 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
24 let _permit = self
25 .semaphore
26 .acquire()
27 .await
28 .map_err(|e| SynapticError::Model(format!("semaphore error: {e}")))?;
29 self.inner.chat(request).await
30 }
31
32 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
33 let inner = self.inner.clone();
34 let semaphore = self.semaphore.clone();
35
36 Box::pin(async_stream::stream! {
37 let _permit = match semaphore.acquire_owned().await {
38 Ok(p) => p,
39 Err(e) => {
40 yield Err(SynapticError::Model(format!("semaphore error: {e}")));
41 return;
42 }
43 };
44
45 use futures::StreamExt;
46 let mut stream = inner.stream_chat(request);
47 while let Some(result) = stream.next().await {
48 yield result;
49 }
50 })
51 }
52}