1use async_openai::{
7 config::OpenAIConfig,
8 types::chat::{
9 ChatCompletionRequestSystemMessage,
10 ChatCompletionRequestUserMessage,
11 CreateChatCompletionRequestArgs,
12 },
13 Client,
14};
15use serde::de::DeserializeOwned;
16use std::borrow::Cow;
17use std::sync::Arc;
18use tracing::{debug, instrument};
19
20use super::config::LlmConfig;
21use super::error::{LlmError, LlmResult};
22use super::fallback::FallbackChain;
23use super::retry::with_retry;
24use crate::throttle::ConcurrencyController;
25
/// Client for chat-completion requests against an OpenAI-compatible endpoint.
///
/// Cloning is shallow for the optional components: `concurrency` and
/// `fallback` are shared behind `Arc`, so clones cooperate on the same
/// limiter / fallback chain.
#[derive(Clone)]
pub struct LlmClient {
    // Model, endpoint, retry and sampling settings (see `LlmConfig`).
    config: LlmConfig,
    // Optional request limiter; a permit is acquired for each in-flight call.
    concurrency: Option<Arc<ConcurrencyController>>,
    // Optional fallback chain — stored and exposed via `fallback()`, but not
    // consumed anywhere in this file; presumably used by callers. TODO confirm.
    fallback: Option<Arc<FallbackChain>>,
}
68
69impl std::fmt::Debug for LlmClient {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("LlmClient")
72 .field("model", &self.config.model)
73 .field("endpoint", &self.config.endpoint)
74 .field("concurrency", &self.concurrency.as_ref().map(|c| format!("{:?}", c)))
75 .field("fallback_enabled", &self.fallback.is_some())
76 .finish()
77 }
78}
79
80impl LlmClient {
81 pub fn new(config: LlmConfig) -> Self {
83 Self {
84 config,
85 concurrency: None,
86 fallback: None,
87 }
88 }
89
90 pub fn with_defaults() -> Self {
92 Self::new(LlmConfig::default())
93 }
94
95 pub fn for_model(model: impl Into<String>) -> Self {
97 Self::new(LlmConfig::new(model))
98 }
99
100 pub fn with_concurrency(mut self, controller: ConcurrencyController) -> Self {
116 self.concurrency = Some(Arc::new(controller));
117 self
118 }
119
120 pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
122 self.concurrency = Some(controller);
123 self
124 }
125
126 pub fn with_fallback(mut self, chain: FallbackChain) -> Self {
140 self.fallback = Some(Arc::new(chain));
141 self
142 }
143
144 pub fn with_shared_fallback(mut self, chain: Arc<FallbackChain>) -> Self {
146 self.fallback = Some(chain);
147 self
148 }
149
150 pub fn config(&self) -> &LlmConfig {
152 &self.config
153 }
154
155 pub fn concurrency(&self) -> Option<&ConcurrencyController> {
157 self.concurrency.as_deref()
158 }
159
160 pub fn fallback(&self) -> Option<&FallbackChain> {
162 self.fallback.as_deref()
163 }
164
165 #[instrument(skip(self, system, user), fields(model = %self.config.model))]
171 pub async fn complete(&self, system: &str, user: &str) -> LlmResult<String> {
172 with_retry(&self.config.retry, || async {
173 self.complete_once(system, user).await
174 }).await
175 }
176
177 pub async fn complete_with_max_tokens(
179 &self,
180 system: &str,
181 user: &str,
182 max_tokens: u16,
183 ) -> LlmResult<String> {
184 with_retry(&self.config.retry, || async {
185 self.complete_once_with_max_tokens(system, user, max_tokens).await
186 }).await
187 }
188
189 pub async fn complete_json<T: DeserializeOwned>(
216 &self,
217 system: &str,
218 user: &str,
219 ) -> LlmResult<T> {
220 let response = self.complete(system, user).await?;
221 self.parse_json(&response)
222 }
223
224 pub async fn complete_json_with_max_tokens<T: DeserializeOwned>(
226 &self,
227 system: &str,
228 user: &str,
229 max_tokens: u16,
230 ) -> LlmResult<T> {
231 let response = self.complete_with_max_tokens(system, user, max_tokens).await?;
232 self.parse_json(&response)
233 }
234
235 async fn complete_once(&self, system: &str, user: &str) -> LlmResult<String> {
237 let _permit = if let Some(ref cc) = self.concurrency {
239 Some(cc.acquire().await)
240 } else {
241 None
242 };
243
244 let api_key = self.config.get_api_key()
245 .ok_or_else(|| LlmError::Config(
246 "No API key found. Set OPENAI_API_KEY environment variable.".to_string()
247 ))?;
248
249 let endpoint = self.config.auto_detect_endpoint();
250 let model = self.config.auto_detect_model();
251
252 println!("Using OpenAI API endpoint: {}", endpoint);
253 println!("Using OpenAI model: {}", model);
254
255 let openai_config = OpenAIConfig::new()
256 .with_api_key(api_key)
257 .with_api_base(&endpoint);
258
259 let client = Client::with_config(openai_config);
260
261 let truncated = self.truncate_prompt(user);
263
264 let request = CreateChatCompletionRequestArgs::default()
265 .model(&model)
266 .messages([
267 ChatCompletionRequestSystemMessage::from(system).into(),
268 ChatCompletionRequestUserMessage::from(truncated).into(),
269 ])
270 .temperature(self.config.temperature)
272 .build()
273 .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
274
275 debug!("Sending LLM request to {} with model {}", endpoint, model);
276
277 let response = client.chat().create(request).await
278 .map_err(|e| {
279 let msg = e.to_string();
280 LlmError::from_api_message(&msg)
281 })?;
282
283 let content = response
284 .choices
285 .first()
286 .and_then(|choice| choice.message.content.clone())
287 .ok_or(LlmError::NoContent)?;
288
289 debug!("LLM response length: {} chars", content.len());
290
291 Ok(content)
292 }
293
294 async fn complete_once_with_max_tokens(
296 &self,
297 system: &str,
298 user: &str,
299 max_tokens: u16,
300 ) -> LlmResult<String> {
301 let _permit = if let Some(ref cc) = self.concurrency {
303 Some(cc.acquire().await)
304 } else {
305 None
306 };
307
308 let api_key = self.config.get_api_key()
309 .ok_or_else(|| LlmError::Config(
310 "No API key found. Set OPENAI_API_KEY environment variable.".to_string()
311 ))?;
312
313 let endpoint = self.config.auto_detect_endpoint();
314 let model = self.config.auto_detect_model();
315
316 let openai_config = OpenAIConfig::new()
317 .with_api_key(api_key)
318 .with_api_base(&endpoint);
319
320 let client = Client::with_config(openai_config);
321
322 let truncated = self.truncate_prompt(user);
323
324 let request = CreateChatCompletionRequestArgs::default()
325 .model(&model)
326 .messages([
327 ChatCompletionRequestSystemMessage::from(system).into(),
328 ChatCompletionRequestUserMessage::from(truncated).into(),
329 ])
330 .temperature(self.config.temperature)
332 .build()
333 .map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
334
335 let response = client.chat().create(request).await
336 .map_err(|e| {
337 let msg = e.to_string();
338 eprintln!("[LLM ERROR] API error: {}", msg);
339 LlmError::from_api_message(&msg)
340 })?;
341
342 eprintln!("[LLM DEBUG] Response: {} choices", response.choices.len());
344 if let Some(choice) = response.choices.first() {
345 eprintln!("[LLM DEBUG] First choice: finish_reason={:?}, has_content={}",
346 choice.finish_reason,
347 choice.message.content.is_some()
348 );
349 }
350
351 let content = response
352 .choices
353 .first()
354 .and_then(|choice| choice.message.content.clone())
355 .ok_or_else(|| {
356 eprintln!("[LLM ERROR] Response has no content");
357 LlmError::NoContent
358 })?;
359
360 if content.is_empty() {
361 eprintln!("[LLM WARN] Returned empty content for model: {}", model);
362 } else {
363 eprintln!("[LLM DEBUG] Content length: {} chars", content.len());
364 }
365
366 Ok(content)
367 }
368
369 fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str {
371 const MAX_CHARS: usize = 30000;
373 if text.len() > MAX_CHARS {
374 &text[..MAX_CHARS]
375 } else {
376 text
377 }
378 }
379
380 fn parse_json<T: DeserializeOwned>(&self, text: &str) -> LlmResult<T> {
382 let json_text = self.extract_json(text);
383 serde_json::from_str(&json_text)
384 .map_err(|e| LlmError::Parse(format!("Failed to parse JSON: {}. Response: {}", e, text)))
385 }
386
387 fn extract_json<'a>(&self, text: &'a str) -> Cow<'a, str> {
389 let text = text.trim();
390
391 if text.starts_with("```") {
393 if let Some(start) = text.find('\n') {
395 let rest = &text[start + 1..];
396 if let Some(end) = rest.find("```") {
397 return Cow::Borrowed(rest[..end].trim());
398 }
399 }
400 }
401
402 if text.starts_with('[') || text.starts_with('{') {
404 let open = text.chars().next().unwrap();
405 let close = if open == '[' { ']' } else { '}' };
406
407 let mut depth = 0;
408 for (i, ch) in text.char_indices() {
409 match ch {
410 c if c == open => depth += 1,
411 c if c == close => {
412 depth -= 1;
413 if depth == 0 {
414 return Cow::Borrowed(&text[..=i]);
415 }
416 }
417 _ => {}
418 }
419 }
420 }
421
422 Cow::Borrowed(text)
423 }
424}
425
426impl Default for LlmClient {
427 fn default() -> Self {
428 Self::with_defaults()
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    // Bare JSON passes through `extract_json` unchanged.
    #[test]
    fn test_extract_json_plain() {
        let raw = r#"{"key": "value"}"#;
        assert_eq!(LlmClient::with_defaults().extract_json(raw), raw);
    }

    // A fenced ```json block is unwrapped down to its payload.
    #[test]
    fn test_extract_json_code_block() {
        let fenced = r#"```json
{"key": "value"}
```"#;
        let extracted = LlmClient::with_defaults().extract_json(fenced);
        assert_eq!(extracted, r#"{"key": "value"}"#);
    }

    // Arrays are recognized as well as objects.
    #[test]
    fn test_extract_json_array() {
        let raw = r#"[1, 2, 3]"#;
        assert_eq!(LlmClient::with_defaults().extract_json(raw), raw);
    }

    // Nested braces keep the whole outer object.
    #[test]
    fn test_extract_json_nested() {
        let raw = r#"{"outer": {"inner": 1}}"#;
        assert_eq!(LlmClient::with_defaults().extract_json(raw), raw);
    }

    // `for_model` stores the requested model name in the config.
    #[test]
    fn test_client_creation() {
        assert_eq!(LlmClient::for_model("gpt-4o").config.model, "gpt-4o");
    }

    // `with_concurrency` wraps and stores the controller.
    #[test]
    fn test_client_with_concurrency() {
        use crate::throttle::ConcurrencyConfig;

        let limiter = ConcurrencyController::new(ConcurrencyConfig::conservative());
        let client = LlmClient::for_model("gpt-4o-mini").with_concurrency(limiter);
        assert!(client.concurrency.is_some());
    }
}