1use anyhow::Result;
4use hashbrown::HashMap;
5use parking_lot::Mutex;
6use serde_json::Value;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::time::{Duration, Instant};
10use tokio::sync::{Notify, RwLock, Semaphore, mpsc};
11use tracing::debug;
12
13use crate::llm::types::LLMError;
14
/// A single chat request submitted through [`OptimizedLLMClient::chat_optimized`].
///
/// `model`, `temperature` and `max_tokens` are hashed into the batch key; the
/// response-cache key additionally hashes every entry of `messages`.
#[derive(Debug, Clone)]
pub struct OptimizedRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages as raw JSON values; each one is serialized with
    /// `to_string` when the cache key is computed.
    pub messages: Vec<Value>,
    /// Optional temperature setting (presumably the sampling temperature —
    /// confirm against the provider API). Hashed by its bit pattern.
    pub temperature: Option<f32>,
    /// Optional token limit (presumably for the generated completion — confirm).
    pub max_tokens: Option<u32>,
}
23
/// Result of a request executed through the batcher; also the value stored in
/// the response cache.
#[derive(Debug, Clone)]
pub struct OptimizedResponse {
    /// Text content of the response.
    pub content: String,
    /// Optional usage information as raw JSON (schema not visible in this file).
    pub usage: Option<Value>,
}
30
/// A small pool of reusable `reqwest::Client` instances.
pub struct ConnectionPool {
    /// Idle clients available for reuse: `get_client` pops from this list and
    /// `return_client` pushes onto it.
    clients: Arc<RwLock<Vec<reqwest::Client>>>,

    /// Maximum number of idle clients retained by `return_client`; also the
    /// initial permit count of `active_connections`.
    max_size: usize,

    /// Semaphore created with `max_size` permits and read by `utilization`.
    /// NOTE(review): nothing in the code visible here acquires a permit, so
    /// `utilization` reports zero unless permits are taken elsewhere — confirm.
    active_connections: Arc<Semaphore>,
}
42
/// Groups queued requests by generation settings and dispatches them either
/// when a group reaches `max_batch_size` or when the background task flushes.
pub struct RequestBatcher {
    /// Queued requests keyed by the batch key (a hash of model, temperature
    /// and max_tokens).
    pending_requests: Arc<RwLock<HashMap<String, Vec<BatchedRequest>>>>,

    /// Signalled by `add_request` when a request was queued without filling
    /// its batch; awaited by the background task.
    work_notify: Arc<Notify>,

    /// Delay between the background task being woken and the flush of all
    /// pending batches.
    batch_interval: Duration,

    /// Batch length at which `add_request` dispatches the batch immediately.
    max_batch_size: usize,

    /// Set by `start_processing` so that only one background task is spawned;
    /// cleared by `shutdown_processing`.
    processing_started: AtomicBool,

    /// Sender half of the shutdown channel for the background task, present
    /// while the task is running.
    shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,

    /// Handle of the background task; awaited by `shutdown_processing` and
    /// aborted when the batcher is dropped.
    processing_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
66
/// A request waiting in the batcher together with the channel used to deliver
/// its result back to the caller.
#[derive(Debug)]
pub struct BatchedRequest {
    /// Identifier of the request; `chat_optimized` fills it with a v4 UUID.
    /// Not read by the code visible in this file.
    pub id: String,
    /// The request to execute.
    pub request: OptimizedRequest,
    /// One-shot channel on which the outcome is sent once the request has run.
    pub response_tx: tokio::sync::oneshot::Sender<Result<OptimizedResponse, LLMError>>,
    /// Timestamp supplied by the submitter (`chat_optimized` passes the time
    /// the call started). Not read by the code visible in this file.
    pub submitted_at: Instant,
}
75
/// Client that combines response caching, rate limiting and request batching.
pub struct OptimizedLLMClient {
    /// HTTP client pool; in the code visible here it is only consulted for
    /// its utilization figure in `get_metrics`.
    connection_pool: Arc<ConnectionPool>,

    /// Batcher that queues and executes cache-miss requests.
    request_batcher: Arc<RequestBatcher>,

    /// LRU cache of responses keyed by the hash from `generate_cache_key`.
    response_cache: Arc<RwLock<lru::LruCache<String, CachedResponse>>>,

    /// Token-bucket limiter acquired before a request is enqueued.
    rate_limiter: Arc<RateLimiter>,

    /// Counters and averages updated by `chat_optimized`.
    metrics: Arc<RwLock<ClientMetrics>>,
}
93
/// A cached response plus the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// The response served on a cache hit.
    pub response: OptimizedResponse,
    /// When the entry was stored.
    pub cached_at: Instant,
    /// Lifetime of the entry; it is served only while
    /// `cached_at.elapsed() < ttl`.
    pub ttl: Duration,
}
101
/// Asynchronous rate limiter backed by a single shared [`TokenBucket`].
pub struct RateLimiter {
    /// Bucket state; `acquire` takes the write lock to refill and consume.
    token_bucket: Arc<RwLock<TokenBucket>>,
}
107
/// State of a token bucket: one token is consumed per acquired request.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    pub tokens: f64,
    /// Upper bound that `tokens` is clamped to on refill.
    pub capacity: f64,
    /// Tokens added per second of elapsed time.
    pub refill_rate: f64,
    /// Instant of the most recent refill.
    pub last_refill: Instant,
}
116
/// Snapshot of client statistics; every field starts at zero via `Default`.
#[derive(Debug, Default, Clone)]
pub struct ClientMetrics {
    /// Requests that completed through the batcher (cache hits are counted
    /// separately in `cache_hits`).
    pub total_requests: u64,
    /// Requests answered from the response cache.
    pub cache_hits: u64,
    /// NOTE(review): never updated by the code visible in this file.
    pub batched_requests: u64,
    /// Exponentially weighted moving average of the latency, in milliseconds,
    /// of requests that completed through the batcher.
    pub avg_response_time_ms: f64,
    /// Pool utilization; filled in by `get_metrics` when a snapshot is taken.
    pub connection_pool_utilization: f64,
    /// NOTE(review): never updated by the code visible in this file.
    pub rate_limit_hits: u64,
}
127
128impl ConnectionPool {
129 pub fn new(max_size: usize) -> Self {
130 let clients = Vec::with_capacity(max_size);
131
132 Self {
133 clients: Arc::new(RwLock::new(clients)),
134 max_size,
135 active_connections: Arc::new(Semaphore::new(max_size)),
136 }
137 }
138
139 pub async fn get_client(&self) -> Result<reqwest::Client> {
141 {
143 let mut clients = self.clients.write().await;
144 if let Some(client) = clients.pop() {
145 return Ok(client);
146 }
147 }
148
149 let client = reqwest::Client::builder()
151 .pool_max_idle_per_host(10)
152 .pool_idle_timeout(Duration::from_secs(30))
153 .timeout(Duration::from_secs(60))
154 .tcp_keepalive(Duration::from_secs(60))
155 .http2_prior_knowledge()
156 .build()?;
157
158 Ok(client)
159 }
160
161 pub async fn return_client(&self, client: reqwest::Client) {
163 let mut clients = self.clients.write().await;
164 if clients.len() < self.max_size {
165 clients.push(client);
166 }
167 }
168
169 pub async fn utilization(&self) -> f64 {
171 let available = self.active_connections.available_permits();
172 let total = self.max_size;
173 (total - available) as f64 / total as f64
174 }
175}
176
177impl RequestBatcher {
178 pub fn new(batch_interval: Duration, max_batch_size: usize) -> Self {
179 Self {
180 pending_requests: Arc::new(RwLock::new(HashMap::new())),
181 work_notify: Arc::new(Notify::new()),
182 batch_interval,
183 max_batch_size,
184 processing_started: AtomicBool::new(false),
185 shutdown_tx: Mutex::new(None),
186 processing_task: Mutex::new(None),
187 }
188 }
189
190 pub async fn add_request(&self, request: BatchedRequest) -> Result<()> {
192 let batch_key = self.generate_batch_key(&request.request);
193
194 let mut pending = self.pending_requests.write().await;
195 let batch = pending.entry(batch_key).or_insert_with(Vec::new);
196
197 batch.push(request);
198
199 if batch.len() >= self.max_batch_size {
201 let batch_requests = std::mem::take(batch);
203 drop(pending);
204
205 tokio::spawn(async move {
206 Self::process_batch(batch_requests).await;
207 });
208 } else {
209 drop(pending);
210 self.work_notify.notify_one();
211 }
212
213 Ok(())
214 }
215
216 fn generate_batch_key(&self, request: &OptimizedRequest) -> String {
218 use std::collections::hash_map::DefaultHasher;
219 use std::hash::{Hash, Hasher};
220
221 let mut hasher = DefaultHasher::new();
222
223 request.model.hash(&mut hasher);
224 request.temperature.map(f32::to_bits).hash(&mut hasher);
225 request.max_tokens.hash(&mut hasher);
226
227 format!("{:x}", hasher.finish())
228 }
229
230 async fn process_batch(requests: Vec<BatchedRequest>) {
232 debug!("Processing batch of {} requests", requests.len());
233
234 let mut tasks = tokio::task::JoinSet::new();
235 for request in requests {
236 tasks.spawn(async move {
237 let result = Self::execute_single_request(request.request).await;
238 let _ = request.response_tx.send(result);
239 });
240 }
241
242 while let Some(result) = tasks.join_next().await {
243 if let Err(error) = result {
244 debug!(?error, "batched request task failed");
245 }
246 }
247 }
248
249 async fn execute_single_request(
251 _request: OptimizedRequest,
252 ) -> Result<OptimizedResponse, LLMError> {
253 tokio::time::sleep(Duration::from_millis(100)).await;
255
256 Ok(OptimizedResponse {
257 content: "Batched response".to_string(),
258 usage: None,
259 })
260 }
261
262 pub async fn start_processing(&self) {
264 if self.processing_started.swap(true, Ordering::SeqCst) {
265 return;
266 }
267
268 let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
269 *self.shutdown_tx.lock() = Some(shutdown_tx);
270
271 let pending_requests = Arc::clone(&self.pending_requests);
272 let work_notify = Arc::clone(&self.work_notify);
273 let batch_interval = self.batch_interval;
274
275 let processing_task = tokio::spawn(async move {
276 loop {
277 tokio::select! {
278 _ = shutdown_rx.recv() => {
279 debug!("LLM request batch processing shutdown requested");
280 break;
281 }
282 _ = work_notify.notified() => {}
283 }
284
285 let flush_deadline = tokio::time::Instant::now() + batch_interval;
286 let sleep_until_flush = tokio::time::sleep_until(flush_deadline);
287 tokio::pin!(sleep_until_flush);
288
289 loop {
290 tokio::select! {
291 _ = shutdown_rx.recv() => {
292 debug!("LLM request batch processing shutdown requested");
293 return;
294 }
295 _ = &mut sleep_until_flush => {
296 let batches_to_process = Self::take_pending_batches(&pending_requests).await;
297 for batch in batches_to_process {
298 tokio::spawn(async move {
299 Self::process_batch(batch).await;
300 });
301 }
302 break;
303 }
304 _ = work_notify.notified() => {}
305 }
306 }
307 }
308 });
309 *self.processing_task.lock() = Some(processing_task);
310 }
311
312 async fn take_pending_batches(
313 pending_requests: &Arc<RwLock<HashMap<String, Vec<BatchedRequest>>>>,
314 ) -> Vec<Vec<BatchedRequest>> {
315 let mut pending = pending_requests.write().await;
316 let mut batches = Vec::new();
317
318 for requests in pending.values_mut() {
319 if !requests.is_empty() {
320 batches.push(std::mem::take(requests));
321 }
322 }
323
324 pending.retain(|_, requests| !requests.is_empty());
325 batches
326 }
327
328 pub async fn shutdown_processing(&self) {
329 let shutdown_tx = { self.shutdown_tx.lock().take() };
330 if let Some(tx) = shutdown_tx {
331 let _ = tx.send(()).await;
332 }
333
334 let handle = { self.processing_task.lock().take() };
335 if let Some(handle) = handle {
336 let _ = handle.await;
337 }
338
339 self.processing_started.store(false, Ordering::SeqCst);
340 }
341}
342
343impl Drop for RequestBatcher {
344 fn drop(&mut self) {
345 if let Some(handle) = self.processing_task.lock().take() {
346 handle.abort();
347 }
348 self.shutdown_tx.lock().take();
349 }
350}
351
352impl RateLimiter {
353 pub fn new(requests_per_second: f64, burst_capacity: usize) -> Self {
354 let refill_rate = if requests_per_second.is_finite() && requests_per_second > 0.0 {
355 requests_per_second
356 } else {
357 1.0
358 };
359 let burst_capacity = burst_capacity.max(1);
360
361 Self {
362 token_bucket: Arc::new(RwLock::new(TokenBucket {
363 tokens: burst_capacity as f64,
364 capacity: burst_capacity as f64,
365 refill_rate,
366 last_refill: Instant::now(),
367 })),
368 }
369 }
370
371 pub async fn acquire(&self) -> Result<()> {
373 loop {
374 let wait_time = {
375 let mut bucket = self.token_bucket.write().await;
376 Self::refill_tokens(&mut bucket);
377
378 if bucket.tokens >= 1.0 {
379 bucket.tokens -= 1.0;
380 return Ok(());
381 }
382
383 let wait_secs = (1.0 - bucket.tokens) / bucket.refill_rate;
384 Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::from_secs(60))
385 };
386
387 tokio::time::sleep(wait_time).await;
388 }
389 }
390
391 fn refill_tokens(bucket: &mut TokenBucket) {
393 let now = Instant::now();
394 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
395
396 let tokens_to_add = elapsed * bucket.refill_rate;
397 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.capacity);
398 bucket.last_refill = now;
399 }
400}
401
402impl OptimizedLLMClient {
403 pub fn new(
404 pool_size: usize,
405 cache_size: usize,
406 requests_per_second: f64,
407 burst_capacity: usize,
408 ) -> Self {
409 Self {
410 connection_pool: Arc::new(ConnectionPool::new(pool_size)),
411 request_batcher: Arc::new(RequestBatcher::new(Duration::from_millis(100), 10)),
412 response_cache: Arc::new(RwLock::new(lru::LruCache::new(
413 std::num::NonZeroUsize::new(cache_size).unwrap_or(std::num::NonZeroUsize::MIN),
414 ))),
415 rate_limiter: Arc::new(RateLimiter::new(requests_per_second, burst_capacity)),
416 metrics: Arc::new(RwLock::new(ClientMetrics::default())),
417 }
418 }
419
420 pub async fn chat_optimized(
422 &self,
423 request: OptimizedRequest,
424 ) -> Result<OptimizedResponse, LLMError> {
425 let start_time = Instant::now();
426
427 let cache_key = self.generate_cache_key(&request);
429
430 let cached_response = {
432 let cache = self.response_cache.read().await;
433 cache
434 .peek(&cache_key)
435 .filter(|cached| cached.cached_at.elapsed() < cached.ttl)
436 .map(|cached| cached.response.clone())
437 };
438 if let Some(response) = cached_response {
439 self.metrics.write().await.cache_hits += 1;
440 return Ok(response);
441 }
442
443 self.request_batcher.start_processing().await;
444
445 self.rate_limiter
447 .acquire()
448 .await
449 .map_err(|_e| LLMError::RateLimit { metadata: None })?;
450
451 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
453 let batched_request = BatchedRequest {
454 id: uuid::Uuid::new_v4().to_string(),
455 request,
456 response_tx,
457 submitted_at: start_time,
458 };
459
460 self.request_batcher
462 .add_request(batched_request)
463 .await
464 .map_err(|e| LLMError::InvalidRequest {
465 message: e.to_string(),
466 metadata: None,
467 })?;
468
469 let response = response_rx.await.map_err(|e| LLMError::InvalidRequest {
471 message: e.to_string(),
472 metadata: None,
473 })??;
474
475 let cached_response = CachedResponse {
477 response: response.clone(),
478 cached_at: Instant::now(),
479 ttl: Duration::from_secs(300), };
481
482 self.response_cache
483 .write()
484 .await
485 .put(cache_key, cached_response);
486
487 let execution_time = start_time.elapsed();
489 let mut metrics = self.metrics.write().await;
490 metrics.total_requests += 1;
491
492 let alpha = 0.1;
494 metrics.avg_response_time_ms = alpha * execution_time.as_millis() as f64
495 + (1.0 - alpha) * metrics.avg_response_time_ms;
496
497 Ok(response)
498 }
499
500 fn generate_cache_key(&self, request: &OptimizedRequest) -> String {
502 use std::collections::hash_map::DefaultHasher;
503 use std::hash::{Hash, Hasher};
504
505 let mut hasher = DefaultHasher::new();
506 request.model.hash(&mut hasher);
507 request.temperature.map(f32::to_bits).hash(&mut hasher);
508 request.max_tokens.hash(&mut hasher);
509
510 for message in &request.messages {
511 message.to_string().hash(&mut hasher);
512 }
513
514 format!("{:x}", hasher.finish())
515 }
516
517 pub async fn start(&self) -> Result<()> {
519 self.request_batcher.start_processing().await;
520 Ok(())
521 }
522
523 pub async fn shutdown(&self) -> Result<()> {
524 self.request_batcher.shutdown_processing().await;
525 Ok(())
526 }
527
528 pub async fn get_metrics(&self) -> ClientMetrics {
530 let mut metrics = self.metrics.read().await.clone();
531 metrics.connection_pool_utilization = self.connection_pool.utilization().await;
532 metrics
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request shared by the tests below.
    fn sample_request() -> OptimizedRequest {
        OptimizedRequest {
            model: "model-a".to_string(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            temperature: Some(0.2),
            max_tokens: Some(128),
        }
    }

    /// Changing the temperature or the token limit must change the cache key.
    #[test]
    fn test_cache_key_includes_generation_settings() {
        let client = OptimizedLLMClient::new(1, 16, 10.0, 1);
        let baseline = sample_request();

        let mut warmer = baseline.clone();
        warmer.temperature = Some(0.8);

        let mut longer = baseline.clone();
        longer.max_tokens = Some(256);

        assert_ne!(
            client.generate_cache_key(&baseline),
            client.generate_cache_key(&warmer)
        );
        assert_ne!(
            client.generate_cache_key(&baseline),
            client.generate_cache_key(&longer)
        );
    }

    /// Requests with different temperatures must land in different batches.
    #[test]
    fn test_batch_key_includes_generation_settings() {
        let batcher = RequestBatcher::new(Duration::from_millis(100), 10);
        let baseline = sample_request();

        let mut warmer = baseline.clone();
        warmer.temperature = Some(0.8);

        assert_ne!(
            batcher.generate_batch_key(&baseline),
            batcher.generate_batch_key(&warmer)
        );
    }

    /// A request must complete even if `start` was never called explicitly.
    #[tokio::test]
    async fn test_chat_optimized_starts_batch_processing_automatically() {
        let client = OptimizedLLMClient::new(1, 16, 10.0, 1);

        let pending_call = client.chat_optimized(sample_request());
        let outcome = tokio::time::timeout(Duration::from_secs(1), pending_call).await;

        let response = outcome
            .expect("request should complete without explicit start")
            .expect("request should succeed");

        assert_eq!(response.content, "Batched response");
    }

    /// A zero burst capacity is raised to one so `acquire` cannot stall.
    #[tokio::test]
    async fn rate_limiter_zero_burst_capacity_still_allows_progress() {
        let limiter = RateLimiter::new(10.0, 0);

        let outcome = tokio::time::timeout(Duration::from_millis(100), limiter.acquire()).await;
        outcome
            .expect("rate limiter should not stall with zero configured burst")
            .expect("rate limiter acquire should succeed");

        let bucket = limiter.token_bucket.read().await;
        assert_eq!(bucket.capacity, 1.0);
    }
}
618}