// vtcode_core/llm/optimized_client.rs

//! Optimized LLM client with connection pooling and request batching

use anyhow::Result;
use hashbrown::HashMap;
use parking_lot::Mutex;
use serde_json::Value;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Notify, RwLock, Semaphore, mpsc};
use tracing::debug;

use crate::llm::types::LLMError;

15/// Simplified request structure for optimization
16#[derive(Debug, Clone)]
17pub struct OptimizedRequest {
18    pub model: String,
19    pub messages: Vec<Value>,
20    pub temperature: Option<f32>,
21    pub max_tokens: Option<u32>,
22}
23
24/// Simplified response structure
25#[derive(Debug, Clone)]
26pub struct OptimizedResponse {
27    pub content: String,
28    pub usage: Option<Value>,
29}
30
31/// Connection pool for HTTP clients
32pub struct ConnectionPool {
33    /// Pool of reusable HTTP clients
34    clients: Arc<RwLock<Vec<reqwest::Client>>>,
35
36    /// Maximum pool size
37    max_size: usize,
38
39    /// Current pool utilization
40    active_connections: Arc<Semaphore>,
41}
42
43/// Request batching manager for similar requests
44pub struct RequestBatcher {
45    /// Pending requests waiting to be batched
46    pending_requests: Arc<RwLock<HashMap<String, Vec<BatchedRequest>>>>,
47
48    /// Wakes the processing loop when new partial batches arrive.
49    work_notify: Arc<Notify>,
50
51    /// Batch processing interval
52    batch_interval: Duration,
53
54    /// Maximum batch size
55    max_batch_size: usize,
56
57    /// Guards against spawning duplicate processing loops.
58    processing_started: AtomicBool,
59
60    /// Shutdown signal sender for the background processing loop.
61    shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
62
63    /// Handle for the background processing loop task.
64    processing_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
65}
66
67/// A request that can be batched with similar requests
68#[derive(Debug)]
69pub struct BatchedRequest {
70    pub id: String,
71    pub request: OptimizedRequest,
72    pub response_tx: tokio::sync::oneshot::Sender<Result<OptimizedResponse, LLMError>>,
73    pub submitted_at: Instant,
74}
75
76/// Optimized LLM client with advanced performance features
77pub struct OptimizedLLMClient {
78    /// Connection pool for HTTP requests
79    connection_pool: Arc<ConnectionPool>,
80
81    /// Request batcher for similar requests
82    request_batcher: Arc<RequestBatcher>,
83
84    /// Response cache for identical requests
85    response_cache: Arc<RwLock<lru::LruCache<String, CachedResponse>>>,
86
87    /// Rate limiter for API calls
88    rate_limiter: Arc<RateLimiter>,
89
90    /// Performance metrics
91    metrics: Arc<RwLock<ClientMetrics>>,
92}
93
94/// Cached response with TTL
95#[derive(Debug, Clone)]
96pub struct CachedResponse {
97    pub response: OptimizedResponse,
98    pub cached_at: Instant,
99    pub ttl: Duration,
100}
101
102/// Rate limiter for API requests
103pub struct RateLimiter {
104    /// Token bucket for burst handling
105    token_bucket: Arc<RwLock<TokenBucket>>,
106}
107
/// Token bucket for rate limiting.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens currently available (fractional while refilling).
    pub tokens: f64,
    /// Upper bound on `tokens`.
    pub capacity: f64,
    /// Tokens added per second of elapsed time.
    pub refill_rate: f64,
    /// Instant of the most recent refill.
    pub last_refill: Instant,
}

/// Client performance metrics.
#[derive(Debug, Default, Clone)]
pub struct ClientMetrics {
    /// Number of requests counted by the client.
    pub total_requests: u64,
    /// Number of requests answered from the response cache.
    pub cache_hits: u64,
    /// Counter for requests routed through the batcher.
    pub batched_requests: u64,
    /// Moving average of response time, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Fraction of the connection pool reported as in use.
    pub connection_pool_utilization: f64,
    /// Counter for rate-limit events.
    pub rate_limit_hits: u64,
}

128impl ConnectionPool {
129    pub fn new(max_size: usize) -> Self {
130        let clients = Vec::with_capacity(max_size);
131
132        Self {
133            clients: Arc::new(RwLock::new(clients)),
134            max_size,
135            active_connections: Arc::new(Semaphore::new(max_size)),
136        }
137    }
138
139    /// Get a client from the pool or create a new one
140    pub async fn get_client(&self) -> Result<reqwest::Client> {
141        // Try to get from pool first
142        {
143            let mut clients = self.clients.write().await;
144            if let Some(client) = clients.pop() {
145                return Ok(client);
146            }
147        }
148
149        // Create new client with optimized settings
150        let client = reqwest::Client::builder()
151            .pool_max_idle_per_host(10)
152            .pool_idle_timeout(Duration::from_secs(30))
153            .timeout(Duration::from_secs(60))
154            .tcp_keepalive(Duration::from_secs(60))
155            .http2_prior_knowledge()
156            .build()?;
157
158        Ok(client)
159    }
160
161    /// Return a client to the pool
162    pub async fn return_client(&self, client: reqwest::Client) {
163        let mut clients = self.clients.write().await;
164        if clients.len() < self.max_size {
165            clients.push(client);
166        }
167    }
168
169    /// Get current pool utilization
170    pub async fn utilization(&self) -> f64 {
171        let available = self.active_connections.available_permits();
172        let total = self.max_size;
173        (total - available) as f64 / total as f64
174    }
175}
176
177impl RequestBatcher {
178    pub fn new(batch_interval: Duration, max_batch_size: usize) -> Self {
179        Self {
180            pending_requests: Arc::new(RwLock::new(HashMap::new())),
181            work_notify: Arc::new(Notify::new()),
182            batch_interval,
183            max_batch_size,
184            processing_started: AtomicBool::new(false),
185            shutdown_tx: Mutex::new(None),
186            processing_task: Mutex::new(None),
187        }
188    }
189
190    /// Add request to batch queue
191    pub async fn add_request(&self, request: BatchedRequest) -> Result<()> {
192        let batch_key = self.generate_batch_key(&request.request);
193
194        let mut pending = self.pending_requests.write().await;
195        let batch = pending.entry(batch_key).or_insert_with(Vec::new);
196
197        batch.push(request);
198
199        // Trigger immediate processing if batch is full
200        if batch.len() >= self.max_batch_size {
201            // Process batch immediately
202            let batch_requests = std::mem::take(batch);
203            drop(pending);
204
205            tokio::spawn(async move {
206                Self::process_batch(batch_requests).await;
207            });
208        } else {
209            drop(pending);
210            self.work_notify.notify_one();
211        }
212
213        Ok(())
214    }
215
216    /// Generate batch key for grouping similar requests
217    fn generate_batch_key(&self, request: &OptimizedRequest) -> String {
218        use std::collections::hash_map::DefaultHasher;
219        use std::hash::{Hash, Hasher};
220
221        let mut hasher = DefaultHasher::new();
222
223        request.model.hash(&mut hasher);
224        request.temperature.map(f32::to_bits).hash(&mut hasher);
225        request.max_tokens.hash(&mut hasher);
226
227        format!("{:x}", hasher.finish())
228    }
229
230    /// Process a batch of similar requests
231    async fn process_batch(requests: Vec<BatchedRequest>) {
232        debug!("Processing batch of {} requests", requests.len());
233
234        let mut tasks = tokio::task::JoinSet::new();
235        for request in requests {
236            tasks.spawn(async move {
237                let result = Self::execute_single_request(request.request).await;
238                let _ = request.response_tx.send(result);
239            });
240        }
241
242        while let Some(result) = tasks.join_next().await {
243            if let Err(error) = result {
244                debug!(?error, "batched request task failed");
245            }
246        }
247    }
248
249    /// Execute a single request (placeholder)
250    async fn execute_single_request(
251        _request: OptimizedRequest,
252    ) -> Result<OptimizedResponse, LLMError> {
253        // Placeholder implementation
254        tokio::time::sleep(Duration::from_millis(100)).await;
255
256        Ok(OptimizedResponse {
257            content: "Batched response".to_string(),
258            usage: None,
259        })
260    }
261
262    /// Start batch processing loop
263    pub async fn start_processing(&self) {
264        if self.processing_started.swap(true, Ordering::SeqCst) {
265            return;
266        }
267
268        let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
269        *self.shutdown_tx.lock() = Some(shutdown_tx);
270
271        let pending_requests = Arc::clone(&self.pending_requests);
272        let work_notify = Arc::clone(&self.work_notify);
273        let batch_interval = self.batch_interval;
274
275        let processing_task = tokio::spawn(async move {
276            loop {
277                tokio::select! {
278                    _ = shutdown_rx.recv() => {
279                        debug!("LLM request batch processing shutdown requested");
280                        break;
281                    }
282                    _ = work_notify.notified() => {}
283                }
284
285                let flush_deadline = tokio::time::Instant::now() + batch_interval;
286                let sleep_until_flush = tokio::time::sleep_until(flush_deadline);
287                tokio::pin!(sleep_until_flush);
288
289                loop {
290                    tokio::select! {
291                        _ = shutdown_rx.recv() => {
292                            debug!("LLM request batch processing shutdown requested");
293                            return;
294                        }
295                        _ = &mut sleep_until_flush => {
296                            let batches_to_process = Self::take_pending_batches(&pending_requests).await;
297                            for batch in batches_to_process {
298                                tokio::spawn(async move {
299                                    Self::process_batch(batch).await;
300                                });
301                            }
302                            break;
303                        }
304                        _ = work_notify.notified() => {}
305                    }
306                }
307            }
308        });
309        *self.processing_task.lock() = Some(processing_task);
310    }
311
312    async fn take_pending_batches(
313        pending_requests: &Arc<RwLock<HashMap<String, Vec<BatchedRequest>>>>,
314    ) -> Vec<Vec<BatchedRequest>> {
315        let mut pending = pending_requests.write().await;
316        let mut batches = Vec::new();
317
318        for requests in pending.values_mut() {
319            if !requests.is_empty() {
320                batches.push(std::mem::take(requests));
321            }
322        }
323
324        pending.retain(|_, requests| !requests.is_empty());
325        batches
326    }
327
328    pub async fn shutdown_processing(&self) {
329        let shutdown_tx = { self.shutdown_tx.lock().take() };
330        if let Some(tx) = shutdown_tx {
331            let _ = tx.send(()).await;
332        }
333
334        let handle = { self.processing_task.lock().take() };
335        if let Some(handle) = handle {
336            let _ = handle.await;
337        }
338
339        self.processing_started.store(false, Ordering::SeqCst);
340    }
341}
342
343impl Drop for RequestBatcher {
344    fn drop(&mut self) {
345        if let Some(handle) = self.processing_task.lock().take() {
346            handle.abort();
347        }
348        self.shutdown_tx.lock().take();
349    }
350}
351
352impl RateLimiter {
353    pub fn new(requests_per_second: f64, burst_capacity: usize) -> Self {
354        let refill_rate = if requests_per_second.is_finite() && requests_per_second > 0.0 {
355            requests_per_second
356        } else {
357            1.0
358        };
359        let burst_capacity = burst_capacity.max(1);
360
361        Self {
362            token_bucket: Arc::new(RwLock::new(TokenBucket {
363                tokens: burst_capacity as f64,
364                capacity: burst_capacity as f64,
365                refill_rate,
366                last_refill: Instant::now(),
367            })),
368        }
369    }
370
371    /// Acquire a permit for making a request
372    pub async fn acquire(&self) -> Result<()> {
373        loop {
374            let wait_time = {
375                let mut bucket = self.token_bucket.write().await;
376                Self::refill_tokens(&mut bucket);
377
378                if bucket.tokens >= 1.0 {
379                    bucket.tokens -= 1.0;
380                    return Ok(());
381                }
382
383                let wait_secs = (1.0 - bucket.tokens) / bucket.refill_rate;
384                Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::from_secs(60))
385            };
386
387            tokio::time::sleep(wait_time).await;
388        }
389    }
390
391    /// Refill token bucket based on elapsed time
392    fn refill_tokens(bucket: &mut TokenBucket) {
393        let now = Instant::now();
394        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
395
396        let tokens_to_add = elapsed * bucket.refill_rate;
397        bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.capacity);
398        bucket.last_refill = now;
399    }
400}
401
402impl OptimizedLLMClient {
403    pub fn new(
404        pool_size: usize,
405        cache_size: usize,
406        requests_per_second: f64,
407        burst_capacity: usize,
408    ) -> Self {
409        Self {
410            connection_pool: Arc::new(ConnectionPool::new(pool_size)),
411            request_batcher: Arc::new(RequestBatcher::new(Duration::from_millis(100), 10)),
412            response_cache: Arc::new(RwLock::new(lru::LruCache::new(
413                std::num::NonZeroUsize::new(cache_size).unwrap_or(std::num::NonZeroUsize::MIN),
414            ))),
415            rate_limiter: Arc::new(RateLimiter::new(requests_per_second, burst_capacity)),
416            metrics: Arc::new(RwLock::new(ClientMetrics::default())),
417        }
418    }
419
420    /// Make an optimized LLM request with caching and batching
421    pub async fn chat_optimized(
422        &self,
423        request: OptimizedRequest,
424    ) -> Result<OptimizedResponse, LLMError> {
425        let start_time = Instant::now();
426
427        // Generate cache key
428        let cache_key = self.generate_cache_key(&request);
429
430        // Check cache first
431        let cached_response = {
432            let cache = self.response_cache.read().await;
433            cache
434                .peek(&cache_key)
435                .filter(|cached| cached.cached_at.elapsed() < cached.ttl)
436                .map(|cached| cached.response.clone())
437        };
438        if let Some(response) = cached_response {
439            self.metrics.write().await.cache_hits += 1;
440            return Ok(response);
441        }
442
443        self.request_batcher.start_processing().await;
444
445        // Acquire rate limit permit
446        self.rate_limiter
447            .acquire()
448            .await
449            .map_err(|_e| LLMError::RateLimit { metadata: None })?;
450
451        // Create batched request
452        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
453        let batched_request = BatchedRequest {
454            id: uuid::Uuid::new_v4().to_string(),
455            request,
456            response_tx,
457            submitted_at: start_time,
458        };
459
460        // Add to batch queue
461        self.request_batcher
462            .add_request(batched_request)
463            .await
464            .map_err(|e| LLMError::InvalidRequest {
465                message: e.to_string(),
466                metadata: None,
467            })?;
468
469        // Wait for response
470        let response = response_rx.await.map_err(|e| LLMError::InvalidRequest {
471            message: e.to_string(),
472            metadata: None,
473        })??;
474
475        // Cache successful response
476        let cached_response = CachedResponse {
477            response: response.clone(),
478            cached_at: Instant::now(),
479            ttl: Duration::from_secs(300), // 5 minutes
480        };
481
482        self.response_cache
483            .write()
484            .await
485            .put(cache_key, cached_response);
486
487        // Update metrics
488        let execution_time = start_time.elapsed();
489        let mut metrics = self.metrics.write().await;
490        metrics.total_requests += 1;
491
492        // Update average response time using exponential moving average
493        let alpha = 0.1;
494        metrics.avg_response_time_ms = alpha * execution_time.as_millis() as f64
495            + (1.0 - alpha) * metrics.avg_response_time_ms;
496
497        Ok(response)
498    }
499
500    /// Generate cache key for request
501    fn generate_cache_key(&self, request: &OptimizedRequest) -> String {
502        use std::collections::hash_map::DefaultHasher;
503        use std::hash::{Hash, Hasher};
504
505        let mut hasher = DefaultHasher::new();
506        request.model.hash(&mut hasher);
507        request.temperature.map(f32::to_bits).hash(&mut hasher);
508        request.max_tokens.hash(&mut hasher);
509
510        for message in &request.messages {
511            message.to_string().hash(&mut hasher);
512        }
513
514        format!("{:x}", hasher.finish())
515    }
516
517    /// Start the client's background processing
518    pub async fn start(&self) -> Result<()> {
519        self.request_batcher.start_processing().await;
520        Ok(())
521    }
522
523    pub async fn shutdown(&self) -> Result<()> {
524        self.request_batcher.shutdown_processing().await;
525        Ok(())
526    }
527
528    /// Get current client metrics
529    pub async fn get_metrics(&self) -> ClientMetrics {
530        let mut metrics = self.metrics.read().await.clone();
531        metrics.connection_pool_utilization = self.connection_pool.utilization().await;
532        metrics
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request shared by the tests below.
    fn sample_request() -> OptimizedRequest {
        OptimizedRequest {
            model: "model-a".to_string(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            temperature: Some(0.2),
            max_tokens: Some(128),
        }
    }

    /// Changing temperature or token limit must change the cache key.
    #[test]
    fn test_cache_key_includes_generation_settings() {
        let client = OptimizedLLMClient::new(1, 16, 10.0, 1);
        let base_request = sample_request();
        let different_temperature = OptimizedRequest {
            temperature: Some(0.8),
            ..base_request.clone()
        };
        let different_max_tokens = OptimizedRequest {
            max_tokens: Some(256),
            ..base_request.clone()
        };

        let base_key = client.generate_cache_key(&base_request);
        assert_ne!(
            base_key,
            client.generate_cache_key(&different_temperature)
        );
        assert_ne!(base_key, client.generate_cache_key(&different_max_tokens));
    }

    /// Requests with different temperatures must not share a batch.
    #[test]
    fn test_batch_key_includes_generation_settings() {
        let batcher = RequestBatcher::new(Duration::from_millis(100), 10);
        let base_request = sample_request();
        let different_request = OptimizedRequest {
            temperature: Some(0.8),
            ..base_request.clone()
        };

        assert_ne!(
            batcher.generate_batch_key(&base_request),
            batcher.generate_batch_key(&different_request)
        );
    }

    /// A request must complete even when `start` was never called.
    #[tokio::test]
    async fn test_chat_optimized_starts_batch_processing_automatically() {
        let client = OptimizedLLMClient::new(1, 16, 10.0, 1);
        let response = tokio::time::timeout(
            Duration::from_secs(1),
            client.chat_optimized(sample_request()),
        )
        .await
        .expect("request should complete without explicit start")
        .expect("request should succeed");

        assert_eq!(response.content, "Batched response");
    }

    /// A zero burst capacity is raised to one so the first acquire succeeds.
    #[tokio::test]
    async fn rate_limiter_zero_burst_capacity_still_allows_progress() {
        let limiter = RateLimiter::new(10.0, 0);

        tokio::time::timeout(Duration::from_millis(100), limiter.acquire())
            .await
            .expect("rate limiter should not stall with zero configured burst")
            .expect("rate limiter acquire should succeed");

        assert_eq!(limiter.token_bucket.read().await.capacity, 1.0);
    }
}