// File: vectorless/llm/client.rs
1// Copyright (c) 2026 vectorless developers
2// SPDX-License-Identifier: Apache-2.0
3
4//! Unified LLM client with retry and concurrency support.
5
6use async_openai::{
7    config::OpenAIConfig,
8    types::chat::{
9        ChatCompletionRequestSystemMessage,
10        ChatCompletionRequestUserMessage,
11        CreateChatCompletionRequestArgs,
12    },
13    Client,
14};
15use serde::de::DeserializeOwned;
16use std::borrow::Cow;
17use std::sync::Arc;
18use tracing::{debug, instrument};
19
20use super::config::LlmConfig;
21use super::error::{LlmError, LlmResult};
22use super::fallback::FallbackChain;
23use super::retry::with_retry;
24use crate::throttle::ConcurrencyController;
25
26/// Unified LLM client.
27///
28/// This client provides:
29/// - Unified interface for all LLM operations
30/// - Automatic retry with exponential backoff
31/// - Rate limiting and concurrency control
32/// - JSON response parsing
33/// - Error classification
34/// - Graceful fallback on errors
35///
36/// # Example
37///
38/// ```rust,no_run
39/// use vectorless::llm::{LlmClient, LlmConfig};
40///
41/// # #[tokio::main]
42/// # async fn main() -> vectorless::llm::LlmResult<()> {
43/// let config = LlmConfig::new("gpt-4o-mini");
44/// let client = LlmClient::new(config);
45///
46/// // Simple completion
47/// let response = client.complete("You are helpful.", "Hello!").await?;
48/// println!("Response: {}", response);
49///
50/// // JSON completion
51/// #[derive(serde::Deserialize)]
52/// struct Answer {
53///     answer: String,
54/// }
55/// let answer: Answer = client.complete_json(
56///     "You answer questions in JSON.",
57///     "What is 2+2?"
58/// ).await?;
59/// # Ok(())
60/// # }
61/// ```
#[derive(Clone)]
pub struct LlmClient {
    /// Model, endpoint, temperature and retry settings for this client.
    config: LlmConfig,
    /// Optional rate limiter + semaphore; wrapped in `Arc` so cloned clients
    /// share the same limits.
    concurrency: Option<Arc<ConcurrencyController>>,
    /// Optional error-recovery chain; `Arc` so clones share one chain.
    fallback: Option<Arc<FallbackChain>>,
}
68
69impl std::fmt::Debug for LlmClient {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("LlmClient")
72            .field("model", &self.config.model)
73            .field("endpoint", &self.config.endpoint)
74            .field("concurrency", &self.concurrency.as_ref().map(|c| format!("{:?}", c)))
75            .field("fallback_enabled", &self.fallback.is_some())
76            .finish()
77    }
78}
79
80impl LlmClient {
81    /// Create a new LLM client with the given configuration.
82    pub fn new(config: LlmConfig) -> Self {
83        Self {
84            config,
85            concurrency: None,
86            fallback: None,
87        }
88    }
89
90    /// Create a client with default configuration.
91    pub fn with_defaults() -> Self {
92        Self::new(LlmConfig::default())
93    }
94
95    /// Create a client for a specific model.
96    pub fn for_model(model: impl Into<String>) -> Self {
97        Self::new(LlmConfig::new(model))
98    }
99
100    /// Add concurrency control to the client.
101    ///
102    /// # Example
103    ///
104    /// ```rust,no_run
105    /// use vectorless::llm::LlmClient;
106    /// use vectorless::throttle::{ConcurrencyController, ConcurrencyConfig};
107    ///
108    /// let config = ConcurrencyConfig::new()
109    ///     .with_max_concurrent_requests(10)
110    ///     .with_requests_per_minute(500);
111    ///
112    /// let client = LlmClient::for_model("gpt-4o-mini")
113    ///     .with_concurrency(ConcurrencyController::new(config));
114    /// ```
115    pub fn with_concurrency(mut self, controller: ConcurrencyController) -> Self {
116        self.concurrency = Some(Arc::new(controller));
117        self
118    }
119
120    /// Add concurrency control from an existing Arc.
121    pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
122        self.concurrency = Some(controller);
123        self
124    }
125
126    /// Add fallback chain for error recovery.
127    ///
128    /// # Example
129    ///
130    /// ```rust
131    /// use vectorless::llm::{LlmClient, FallbackChain, FallbackConfig};
132    ///
133    /// let fallback = FallbackConfig::default();
134    /// let client = LlmClient::for_model("gpt-4o")
135    ///     .with_fallback(FallbackChain::new(fallback));
136    ///
137    /// assert!(client.fallback().is_some());
138    /// ```
139    pub fn with_fallback(mut self, chain: FallbackChain) -> Self {
140        self.fallback = Some(Arc::new(chain));
141        self
142    }
143
144    /// Add fallback chain from an existing Arc.
145    pub fn with_shared_fallback(mut self, chain: Arc<FallbackChain>) -> Self {
146        self.fallback = Some(chain);
147        self
148    }
149
150    /// Get the configuration.
151    pub fn config(&self) -> &LlmConfig {
152        &self.config
153    }
154
155    /// Get the concurrency controller (if any).
156    pub fn concurrency(&self) -> Option<&ConcurrencyController> {
157        self.concurrency.as_deref()
158    }
159
160    /// Get the fallback chain (if any).
161    pub fn fallback(&self) -> Option<&FallbackChain> {
162        self.fallback.as_deref()
163    }
164
165    /// Complete a prompt with system and user messages.
166    ///
167    /// This method includes:
168    /// - Automatic rate limiting (if configured)
169    /// - Automatic retry with exponential backoff
170    #[instrument(skip(self, system, user), fields(model = %self.config.model))]
171    pub async fn complete(&self, system: &str, user: &str) -> LlmResult<String> {
172        with_retry(&self.config.retry, || async {
173            self.complete_once(system, user).await
174        }).await
175    }
176
177    /// Complete a prompt with custom max tokens.
178    pub async fn complete_with_max_tokens(
179        &self,
180        system: &str,
181        user: &str,
182        max_tokens: u16,
183    ) -> LlmResult<String> {
184        with_retry(&self.config.retry, || async {
185            self.complete_once_with_max_tokens(system, user, max_tokens).await
186        }).await
187    }
188
189    /// Complete a prompt and parse the response as JSON.
190    ///
191    /// This method handles:
192    /// - JSON extraction from markdown code blocks
193    /// - Bracket matching for nested JSON
194    ///
195    /// # Example
196    ///
197    /// ```rust,no_run
198    /// # use vectorless::llm::{LlmClient, LlmConfig};
199    /// # #[tokio::main]
200    /// # async fn main() -> vectorless::llm::LlmResult<()> {
201    /// #[derive(serde::Deserialize)]
202    /// struct TocEntry {
203    ///     title: String,
204    ///     page: usize,
205    /// }
206    ///
207    /// let client = LlmClient::for_model("gpt-4o-mini");
208    /// let entries: Vec<TocEntry> = client.complete_json(
209    ///     "Extract TOC entries as JSON array.",
210    ///     "Chapter 1: Introduction ... 5"
211    /// ).await?;
212    /// # Ok(())
213    /// # }
214    /// ```
215    pub async fn complete_json<T: DeserializeOwned>(
216        &self,
217        system: &str,
218        user: &str,
219    ) -> LlmResult<T> {
220        let response = self.complete(system, user).await?;
221        self.parse_json(&response)
222    }
223
224    /// Complete a prompt and parse the response as JSON with custom max tokens.
225    pub async fn complete_json_with_max_tokens<T: DeserializeOwned>(
226        &self,
227        system: &str,
228        user: &str,
229        max_tokens: u16,
230    ) -> LlmResult<T> {
231        let response = self.complete_with_max_tokens(system, user, max_tokens).await?;
232        self.parse_json(&response)
233    }
234
235    /// Single completion attempt (no retry).
236    async fn complete_once(&self, system: &str, user: &str) -> LlmResult<String> {
237        // Acquire concurrency permit (rate limiter + semaphore)
238        let _permit = if let Some(ref cc) = self.concurrency {
239            Some(cc.acquire().await)
240        } else {
241            None
242        };
243
244        let api_key = self.config.get_api_key()
245            .ok_or_else(|| LlmError::Config(
246                "No API key found. Set OPENAI_API_KEY environment variable.".to_string()
247            ))?;
248
249        let endpoint = self.config.auto_detect_endpoint();
250        let model = self.config.auto_detect_model();
251
252        println!("Using OpenAI API endpoint: {}", endpoint);
253        println!("Using OpenAI model: {}", model);
254
255        let openai_config = OpenAIConfig::new()
256            .with_api_key(api_key)
257            .with_api_base(&endpoint);
258
259        let client = Client::with_config(openai_config);
260
261        // Truncate user prompt if too long
262        let truncated = self.truncate_prompt(user);
263
264        let request = CreateChatCompletionRequestArgs::default()
265            .model(&model)
266            .messages([
267                ChatCompletionRequestSystemMessage::from(system).into(),
268                ChatCompletionRequestUserMessage::from(truncated).into(),
269            ])
270            // .max_tokens(self.config.max_tokens as u16)
271            .temperature(self.config.temperature)
272            .build()
273            .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
274
275        debug!("Sending LLM request to {} with model {}", endpoint, model);
276
277        let response = client.chat().create(request).await
278            .map_err(|e| {
279                let msg = e.to_string();
280                LlmError::from_api_message(&msg)
281            })?;
282
283        let content = response
284            .choices
285            .first()
286            .and_then(|choice| choice.message.content.clone())
287            .ok_or(LlmError::NoContent)?;
288
289        debug!("LLM response length: {} chars", content.len());
290
291        Ok(content)
292    }
293
294    /// Single completion with custom max tokens.
295    async fn complete_once_with_max_tokens(
296        &self,
297        system: &str,
298        user: &str,
299        max_tokens: u16,
300    ) -> LlmResult<String> {
301        // Acquire concurrency permit
302        let _permit = if let Some(ref cc) = self.concurrency {
303            Some(cc.acquire().await)
304        } else {
305            None
306        };
307
308        let api_key = self.config.get_api_key()
309            .ok_or_else(|| LlmError::Config(
310                "No API key found. Set OPENAI_API_KEY environment variable.".to_string()
311            ))?;
312
313        let endpoint = self.config.auto_detect_endpoint();
314        let model = self.config.auto_detect_model();
315
316        let openai_config = OpenAIConfig::new()
317            .with_api_key(api_key)
318            .with_api_base(&endpoint);
319
320        let client = Client::with_config(openai_config);
321
322        let truncated = self.truncate_prompt(user);
323
324        let request = CreateChatCompletionRequestArgs::default()
325            .model(&model)
326            .messages([
327                ChatCompletionRequestSystemMessage::from(system).into(),
328                ChatCompletionRequestUserMessage::from(truncated).into(),
329            ])
330            // .max_tokens(max_tokens)
331            .temperature(self.config.temperature)
332            .build()
333            .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
334
335        let response = client.chat().create(request).await
336            .map_err(|e| {
337                let msg = e.to_string();
338                eprintln!("[LLM ERROR] API error: {}", msg);
339                LlmError::from_api_message(&msg)
340            })?;
341
342        // Debug: log response structure
343        eprintln!("[LLM DEBUG] Response: {} choices", response.choices.len());
344        if let Some(choice) = response.choices.first() {
345            eprintln!("[LLM DEBUG] First choice: finish_reason={:?}, has_content={}",
346                choice.finish_reason,
347                choice.message.content.is_some()
348            );
349        }
350
351        let content = response
352            .choices
353            .first()
354            .and_then(|choice| choice.message.content.clone())
355            .ok_or_else(|| {
356                eprintln!("[LLM ERROR] Response has no content");
357                LlmError::NoContent
358            })?;
359
360        if content.is_empty() {
361            eprintln!("[LLM WARN] Returned empty content for model: {}", model);
362        } else {
363            eprintln!("[LLM DEBUG] Content length: {} chars", content.len());
364        }
365
366        Ok(content)
367    }
368
369    /// Truncate a prompt to a reasonable length.
370    fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str {
371        // Roughly 4 chars per token, limit to ~30k chars
372        const MAX_CHARS: usize = 30000;
373        if text.len() > MAX_CHARS {
374            &text[..MAX_CHARS]
375        } else {
376            text
377        }
378    }
379
380    /// Parse JSON from LLM response.
381    fn parse_json<T: DeserializeOwned>(&self, text: &str) -> LlmResult<T> {
382        let json_text = self.extract_json(text);
383        serde_json::from_str(&json_text)
384            .map_err(|e| LlmError::Parse(format!("Failed to parse JSON: {}. Response: {}", e, text)))
385    }
386
387    /// Extract JSON from text (handles markdown code blocks).
388    fn extract_json<'a>(&self, text: &'a str) -> Cow<'a, str> {
389        let text = text.trim();
390
391        // Try markdown code block first
392        if text.starts_with("```") {
393            // Find the end of the first line (language identifier)
394            if let Some(start) = text.find('\n') {
395                let rest = &text[start + 1..];
396                if let Some(end) = rest.find("```") {
397                    return Cow::Borrowed(rest[..end].trim());
398                }
399            }
400        }
401
402        // Try to find JSON array or object
403        if text.starts_with('[') || text.starts_with('{') {
404            let open = text.chars().next().unwrap();
405            let close = if open == '[' { ']' } else { '}' };
406
407            let mut depth = 0;
408            for (i, ch) in text.char_indices() {
409                match ch {
410                    c if c == open => depth += 1,
411                    c if c == close => {
412                        depth -= 1;
413                        if depth == 0 {
414                            return Cow::Borrowed(&text[..=i]);
415                        }
416                    }
417                    _ => {}
418                }
419            }
420        }
421
422        Cow::Borrowed(text)
423    }
424}
425
426impl Default for LlmClient {
427    fn default() -> Self {
428        Self::with_defaults()
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    // Bare JSON object passes through unchanged.
    #[test]
    fn test_extract_json_plain() {
        let input = r#"{"key": "value"}"#;
        let extracted = LlmClient::with_defaults().extract_json(input);
        assert_eq!(extracted, input);
    }

    // A fenced ```json block is unwrapped to its contents.
    #[test]
    fn test_extract_json_code_block() {
        let input = "```json\n{\"key\": \"value\"}\n```";
        let extracted = LlmClient::with_defaults().extract_json(input);
        assert_eq!(extracted, r#"{"key": "value"}"#);
    }

    // Top-level arrays are matched just like objects.
    #[test]
    fn test_extract_json_array() {
        let input = r#"[1, 2, 3]"#;
        let extracted = LlmClient::with_defaults().extract_json(input);
        assert_eq!(extracted, input);
    }

    // Depth counting keeps nested objects intact.
    #[test]
    fn test_extract_json_nested() {
        let input = r#"{"outer": {"inner": 1}}"#;
        let extracted = LlmClient::with_defaults().extract_json(input);
        assert_eq!(extracted, input);
    }

    // `for_model` stores the model name in the config.
    #[test]
    fn test_client_creation() {
        assert_eq!(LlmClient::for_model("gpt-4o").config.model, "gpt-4o");
    }

    // `with_concurrency` wires up the controller.
    #[test]
    fn test_client_with_concurrency() {
        use crate::throttle::ConcurrencyConfig;

        let client = LlmClient::for_model("gpt-4o-mini")
            .with_concurrency(ConcurrencyController::new(ConcurrencyConfig::conservative()));

        assert!(client.concurrency.is_some());
    }
}