// skilllite_agent/llm/mod.rs
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use super::types::{safe_truncate, ChatMessage, EventSink, ToolCall, ToolDefinition, ToolFormat};

mod claude;
mod openai;

#[cfg(test)]
mod tests;
22
23pub fn detect_tool_format(model: &str, api_base: &str) -> ToolFormat {
25 let model_lower = model.to_lowercase();
26 let base_lower = api_base.to_lowercase();
27
28 if model_lower.starts_with("claude")
29 || base_lower.contains("anthropic")
30 || base_lower.contains("claude")
31 {
32 ToolFormat::Claude
33 } else {
34 ToolFormat::OpenAI
35 }
36}
37
/// HTTP client for LLM chat-completion and embedding endpoints.
///
/// Dispatches between the Claude and OpenAI wire formats per request
/// (see [`detect_tool_format`]).
pub struct LlmClient {
    http: reqwest::Client,
    // Base URL with any trailing '/' stripped by `LlmClient::new`.
    api_base: String,
    api_key: String,
}
44
45impl LlmClient {
46 pub fn new(api_base: &str, api_key: &str) -> Result<Self> {
47 let http = reqwest::Client::builder()
48 .timeout(std::time::Duration::from_secs(300))
49 .build()
50 .context("build HTTP client for LLM")?;
51 Ok(Self {
52 http,
53 api_base: api_base.trim_end_matches('/').to_string(),
54 api_key: api_key.to_string(),
55 })
56 }
57
58 pub async fn chat_completion(
60 &self,
61 model: &str,
62 messages: &[ChatMessage],
63 tools: Option<&[ToolDefinition]>,
64 temperature: Option<f64>,
65 ) -> Result<ChatCompletionResponse> {
66 let format = detect_tool_format(model, &self.api_base);
67 match format {
68 ToolFormat::Claude => {
69 self.claude_chat_completion(model, messages, tools, temperature)
70 .await
71 }
72 ToolFormat::OpenAI => {
73 self.openai_chat_completion(model, messages, tools, temperature)
74 .await
75 }
76 }
77 }
78
79 pub async fn chat_completion_stream(
81 &self,
82 model: &str,
83 messages: &[ChatMessage],
84 tools: Option<&[ToolDefinition]>,
85 temperature: Option<f64>,
86 event_sink: &mut dyn EventSink,
87 ) -> Result<ChatCompletionResponse> {
88 let format = detect_tool_format(model, &self.api_base);
89 match format {
90 ToolFormat::Claude => {
91 self.claude_chat_completion_stream(model, messages, tools, temperature, event_sink)
92 .await
93 }
94 ToolFormat::OpenAI => {
95 self.openai_chat_completion_stream(model, messages, tools, temperature, event_sink)
96 .await
97 }
98 }
99 }
100
101 #[allow(dead_code)]
108 pub async fn embed(
109 &self,
110 model: &str,
111 texts: &[&str],
112 custom_url: Option<&str>,
113 custom_key: Option<&str>,
114 ) -> Result<Vec<Vec<f32>>> {
115 if texts.is_empty() {
116 return Ok(Vec::new());
117 }
118 let api_base = custom_url.unwrap_or(&self.api_base);
119 let api_key = custom_key.unwrap_or(&self.api_key);
120 let base = api_base.trim_end_matches('/');
121 let url = if api_base.to_lowercase().contains("minimax") {
122 format!("{}/text/embeddings", base)
123 } else {
124 format!("{}/embeddings", base)
125 };
126 let input: Value = if texts.len() == 1 {
127 json!(texts[0])
128 } else {
129 json!(texts.iter().map(|s| s.to_string()).collect::<Vec<_>>())
130 };
131 let body = json!({ "model": model, "input": input });
132 let resp = self
133 .http
134 .post(&url)
135 .header("Authorization", format!("Bearer {}", api_key))
136 .header("Content-Type", "application/json")
137 .json(&body)
138 .send()
139 .await
140 .context("Embedding API request failed")?;
141 let status = resp.status();
142 if !status.is_success() {
143 let body_text = resp.text().await.unwrap_or_default();
144 anyhow::bail!("Embedding API error ({}): {}", status, body_text);
145 }
146 let json: Value = resp
147 .json()
148 .await
149 .context("Failed to parse embedding response")?;
150
151 if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
153 return Self::extract_embeddings_from_items(data);
154 }
155
156 if let Some(items) = json
158 .get("output")
159 .and_then(|o| o.get("embeddings"))
160 .and_then(|e| e.as_array())
161 {
162 tracing::debug!("Embedding response uses Dashscope native format (output.embeddings)");
163 return Self::extract_embeddings_from_items(items);
164 }
165
166 let preview = serde_json::to_string(&json).unwrap_or_default();
168 let preview = &preview[..preview.len().min(500)];
169 anyhow::bail!(
170 "Unexpected embedding response format (no 'data' or 'output.embeddings'): {}",
171 preview
172 )
173 }
174
175 fn extract_embeddings_from_items(items: &[Value]) -> Result<Vec<Vec<f32>>> {
177 let mut embeddings = Vec::with_capacity(items.len());
178 for item in items {
179 let emb = item
180 .get("embedding")
181 .and_then(|e| e.as_array())
182 .context("Missing 'embedding' in embedding item")?;
183 let vec: Vec<f32> = emb
184 .iter()
185 .filter_map(|v| v.as_f64().map(|f| f as f32))
186 .collect();
187 embeddings.push(vec);
188 }
189 Ok(embeddings)
190 }
191
192 }
196
/// Top-level chat-completion response (OpenAI-compatible shape; Claude
/// responses are converted into this form by the claude module).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token accounting; some providers omit it, hence Option.
    pub usage: Option<Usage>,
}
209
/// One candidate completion within a [`ChatCompletionResponse`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct Choice {
    pub index: u32,
    pub message: ChoiceMessage,
    // e.g. "stop" / "tool_calls"; None when the provider omits it.
    pub finish_reason: Option<String>,
}
217
/// The assistant message carried by a [`Choice`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ChoiceMessage {
    pub role: String,
    // Text content; may be absent when the model only emits tool calls.
    pub content: Option<String>,
    // Reasoning/"thinking" text emitted by some providers; #[serde(default)]
    // makes the field optional on the wire (missing -> None).
    #[serde(default)]
    pub reasoning_content: Option<String>,
    // Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
}
230
/// Token accounting as reported by the provider.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
237
/// Heuristically detects a context-window-overflow error from a provider
/// error message (case-insensitive substring match against known phrasings).
pub fn is_context_overflow_error(err_msg: &str) -> bool {
    const OVERFLOW_MARKERS: [&str; 6] = [
        "context_length_exceeded",
        "maximum context length",
        "token limit",
        "too many tokens",
        "context window",
        "max_tokens",
    ];
    let lower = err_msg.to_lowercase();
    OVERFLOW_MARKERS.iter().any(|marker| lower.contains(marker))
}
251
252pub fn truncate_tool_messages(messages: &mut [ChatMessage], max_chars: usize) {
255 for msg in messages.iter_mut() {
256 if msg.role == "tool" {
257 if let Some(ref mut content) = msg.content {
258 if content.len() > max_chars {
259 let truncated = format!(
260 "{}...\n[truncated: {} chars → {}]",
261 safe_truncate(content, max_chars),
262 content.len(),
263 max_chars
264 );
265 *content = truncated;
266 }
267 }
268 }
269 }
270}