Skip to main content

skilllite_agent/llm/
mod.rs

1//! LLM HTTP client for chat completions with tool calling.
2//!
3//! Supports two API formats:
4//!   - **OpenAI-compatible**: `/chat/completions` (GPT-4, DeepSeek, Qwen, etc.)
5//!   - **Claude Native**: `/v1/messages` (Anthropic Claude)
6//!
7//! Auto-detects which API to use based on model name or API base URL.
8//!
9//! Ported from Python `AgenticLoop._call_openai` / `_call_claude`.
10
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14
15use super::types::{safe_truncate, ChatMessage, EventSink, ToolCall, ToolDefinition, ToolFormat};
16
17mod claude;
18mod openai;
19
20#[cfg(test)]
21mod tests;
22
23/// Detect API format from model name or API base.
24pub fn detect_tool_format(model: &str, api_base: &str) -> ToolFormat {
25    let model_lower = model.to_lowercase();
26    let base_lower = api_base.to_lowercase();
27
28    if model_lower.starts_with("claude")
29        || base_lower.contains("anthropic")
30        || base_lower.contains("claude")
31    {
32        ToolFormat::Claude
33    } else {
34        ToolFormat::OpenAI
35    }
36}
37
38/// LLM client supporting both OpenAI and Claude API formats.
39pub struct LlmClient {
40    http: reqwest::Client,
41    api_base: String,
42    api_key: String,
43}
44
45impl LlmClient {
46    pub fn new(api_base: &str, api_key: &str) -> Result<Self> {
47        let http = reqwest::Client::builder()
48            .timeout(std::time::Duration::from_secs(300))
49            .build()
50            .context("build HTTP client for LLM")?;
51        Ok(Self {
52            http,
53            api_base: api_base.trim_end_matches('/').to_string(),
54            api_key: api_key.to_string(),
55        })
56    }
57
58    /// Non-streaming chat completion call (auto-routes based on model/api_base).
59    pub async fn chat_completion(
60        &self,
61        model: &str,
62        messages: &[ChatMessage],
63        tools: Option<&[ToolDefinition]>,
64        temperature: Option<f64>,
65    ) -> Result<ChatCompletionResponse> {
66        let format = detect_tool_format(model, &self.api_base);
67        match format {
68            ToolFormat::Claude => {
69                self.claude_chat_completion(model, messages, tools, temperature)
70                    .await
71            }
72            ToolFormat::OpenAI => {
73                self.openai_chat_completion(model, messages, tools, temperature)
74                    .await
75            }
76        }
77    }
78
79    /// Streaming chat completion call (auto-routes based on model/api_base).
80    pub async fn chat_completion_stream(
81        &self,
82        model: &str,
83        messages: &[ChatMessage],
84        tools: Option<&[ToolDefinition]>,
85        temperature: Option<f64>,
86        event_sink: &mut dyn EventSink,
87    ) -> Result<ChatCompletionResponse> {
88        let format = detect_tool_format(model, &self.api_base);
89        match format {
90            ToolFormat::Claude => {
91                self.claude_chat_completion_stream(model, messages, tools, temperature, event_sink)
92                    .await
93            }
94            ToolFormat::OpenAI => {
95                self.openai_chat_completion_stream(model, messages, tools, temperature, event_sink)
96                    .await
97            }
98        }
99    }
100
101    /// Embed text(s) using OpenAI-compatible /embeddings API.
102    /// Returns one embedding vector per input string. Used when memory_vector feature is enabled.
103    /// If custom_url and custom_key are provided, use them instead of self.api_base/self.api_key.
104    ///
105    /// Handles both OpenAI-standard format (`{"data": [{"embedding": [...]}]}`)
106    /// and Dashscope native format (`{"output": {"embeddings": [{"embedding": [...]}]}}`).
107    #[allow(dead_code)]
108    pub async fn embed(
109        &self,
110        model: &str,
111        texts: &[&str],
112        custom_url: Option<&str>,
113        custom_key: Option<&str>,
114    ) -> Result<Vec<Vec<f32>>> {
115        if texts.is_empty() {
116            return Ok(Vec::new());
117        }
118        let api_base = custom_url.unwrap_or(&self.api_base);
119        let api_key = custom_key.unwrap_or(&self.api_key);
120        let base = api_base.trim_end_matches('/');
121        let url = if api_base.to_lowercase().contains("minimax") {
122            format!("{}/text/embeddings", base)
123        } else {
124            format!("{}/embeddings", base)
125        };
126        let input: Value = if texts.len() == 1 {
127            json!(texts[0])
128        } else {
129            json!(texts.iter().map(|s| s.to_string()).collect::<Vec<_>>())
130        };
131        let body = json!({ "model": model, "input": input });
132        let resp = self
133            .http
134            .post(&url)
135            .header("Authorization", format!("Bearer {}", api_key))
136            .header("Content-Type", "application/json")
137            .json(&body)
138            .send()
139            .await
140            .context("Embedding API request failed")?;
141        let status = resp.status();
142        if !status.is_success() {
143            let body_text = resp.text().await.unwrap_or_default();
144            anyhow::bail!("Embedding API error ({}): {}", status, body_text);
145        }
146        let json: Value = resp
147            .json()
148            .await
149            .context("Failed to parse embedding response")?;
150
151        // Try OpenAI-standard format: {"data": [{"embedding": [...]}]}
152        if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
153            return Self::extract_embeddings_from_items(data);
154        }
155
156        // Fallback: Dashscope native format: {"output": {"embeddings": [{"embedding": [...]}]}}
157        if let Some(items) = json
158            .get("output")
159            .and_then(|o| o.get("embeddings"))
160            .and_then(|e| e.as_array())
161        {
162            tracing::debug!("Embedding response uses Dashscope native format (output.embeddings)");
163            return Self::extract_embeddings_from_items(items);
164        }
165
166        // Log the unexpected response shape for debugging
167        let preview = serde_json::to_string(&json).unwrap_or_default();
168        let preview = &preview[..preview.len().min(500)];
169        anyhow::bail!(
170            "Unexpected embedding response format (no 'data' or 'output.embeddings'): {}",
171            preview
172        )
173    }
174
175    /// Extract embedding vectors from a JSON array of items, each containing an "embedding" field.
176    fn extract_embeddings_from_items(items: &[Value]) -> Result<Vec<Vec<f32>>> {
177        let mut embeddings = Vec::with_capacity(items.len());
178        for item in items {
179            let emb = item
180                .get("embedding")
181                .and_then(|e| e.as_array())
182                .context("Missing 'embedding' in embedding item")?;
183            let vec: Vec<f32> = emb
184                .iter()
185                .filter_map(|v| v.as_f64().map(|f| f as f32))
186                .collect();
187            embeddings.push(vec);
188        }
189        Ok(embeddings)
190    }
191
192    // ═══════════════════════════════════════════════════════════════════════════
193    // OpenAI-compatible API
194    // ═══════════════════════════════════════════════════════════════════════════
195}
196
// ─── Response types ─────────────────────────────────────────────────────────
// Fields id/model/usage/index/finish_reason/role mirror the API response schema
// but are not read by our code (hence the `#[allow(dead_code)]` attributes).
200
201#[derive(Debug, Deserialize)]
202#[allow(dead_code)]
203pub struct ChatCompletionResponse {
204    pub id: String,
205    pub model: String,
206    pub choices: Vec<Choice>,
207    pub usage: Option<Usage>,
208}
209
210#[derive(Debug, Deserialize)]
211#[allow(dead_code)]
212pub struct Choice {
213    pub index: u32,
214    pub message: ChoiceMessage,
215    pub finish_reason: Option<String>,
216}
217
218#[derive(Debug, Deserialize)]
219#[allow(dead_code)]
220pub struct ChoiceMessage {
221    pub role: String,
222    pub content: Option<String>,
223    /// Reasoning/thinking content returned separately by reasoning models
224    /// (e.g. DeepSeek R1 via official API or vLLM with --reasoning-parser).
225    /// When present, `content` already excludes the thinking — no tag stripping needed.
226    #[serde(default)]
227    pub reasoning_content: Option<String>,
228    pub tool_calls: Option<Vec<ToolCall>>,
229}
230
231#[derive(Debug, Deserialize, Serialize)]
232pub struct Usage {
233    pub prompt_tokens: u64,
234    pub completion_tokens: u64,
235    pub total_tokens: u64,
236}
237
238// ─── Helpers ────────────────────────────────────────────────────────────────
239
/// Check if an error message looks like a context overflow (token limit exceeded).
///
/// Performs a case-insensitive substring search for known overflow phrasings.
/// Ported from Python `_is_context_overflow_error`, extended with
/// "prompt is too long" and "context limit" so Claude-native overflow errors
/// (e.g. "prompt is too long: 210000 tokens > 200000 maximum") are recognised too.
pub fn is_context_overflow_error(err_msg: &str) -> bool {
    const OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "maximum context length",
        "token limit",
        "too many tokens",
        "context window",
        "max_tokens",
        "prompt is too long",
        "context limit",
    ];
    let lower = err_msg.to_lowercase();
    OVERFLOW_MARKERS.iter().any(|marker| lower.contains(*marker))
}
251
252/// Truncate all tool result messages in place to reduce context size.
253/// Ported from Python `_truncate_tool_messages_in_place`.
254pub fn truncate_tool_messages(messages: &mut [ChatMessage], max_chars: usize) {
255    for msg in messages.iter_mut() {
256        if msg.role == "tool" {
257            if let Some(ref mut content) = msg.content {
258                if content.len() > max_chars {
259                    let truncated = format!(
260                        "{}...\n[truncated: {} chars → {}]",
261                        safe_truncate(content, max_chars),
262                        content.len(),
263                        max_chars
264                    );
265                    *content = truncated;
266                }
267            }
268        }
269    }
270}