1use crate::error::{SageError, SageResult};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
/// Default number of parse attempts `infer_structured` makes before giving up;
/// overridable at runtime via the `SAGE_INFER_RETRIES` environment variable.
const DEFAULT_INFER_RETRIES: usize = 3;
8
/// Async client for an OpenAI-compatible chat-completions endpoint.
#[derive(Clone)]
pub struct LlmClient {
    // HTTP transport used for every request.
    client: reqwest::Client,
    // Endpoint, credentials, and sampling settings applied to each request.
    config: LlmConfig,
}
15
/// Connection, credential, and sampling settings for [`LlmClient`].
#[derive(Clone)]
pub struct LlmConfig {
    /// Bearer token sent in the `Authorization` header; the literal "mock"
    /// marks a client that never touches the network.
    pub api_key: String,
    /// API root such as `https://api.openai.com/v1`.
    pub base_url: String,
    /// Model identifier sent with every request.
    pub model: String,
    /// How many parse attempts `infer_structured` makes before giving up.
    pub infer_retries: usize,
    /// Sampling temperature; omitted from the request body when `None`.
    pub temperature: Option<f64>,
    /// Completion token limit; omitted from the request body when `None`.
    pub max_tokens: Option<i64>,
}
32
33impl LlmConfig {
34 #[cfg(not(target_arch = "wasm32"))]
36 pub fn from_env() -> Self {
37 Self {
38 api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
39 base_url: std::env::var("SAGE_LLM_URL")
40 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
41 model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
42 infer_retries: std::env::var("SAGE_INFER_RETRIES")
43 .ok()
44 .and_then(|s| s.parse().ok())
45 .unwrap_or(DEFAULT_INFER_RETRIES),
46 temperature: std::env::var("SAGE_TEMPERATURE")
47 .ok()
48 .and_then(|s| s.parse().ok()),
49 max_tokens: std::env::var("SAGE_MAX_TOKENS")
50 .ok()
51 .and_then(|s| s.parse().ok()),
52 }
53 }
54
55 #[cfg(target_arch = "wasm32")]
57 pub fn from_env() -> Self {
58 let wasm_config = sage_runtime_web::get_llm_config();
59 Self {
60 api_key: wasm_config.api_key,
61 base_url: wasm_config.base_url,
62 model: wasm_config.model,
63 infer_retries: DEFAULT_INFER_RETRIES,
64 temperature: None,
65 max_tokens: None,
66 }
67 }
68
69 pub fn mock() -> Self {
71 Self {
72 api_key: "mock".to_string(),
73 base_url: "mock".to_string(),
74 model: "mock".to_string(),
75 infer_retries: DEFAULT_INFER_RETRIES,
76 temperature: None,
77 max_tokens: None,
78 }
79 }
80
81 pub fn with_model(model: impl Into<String>) -> Self {
86 let mut config = Self::from_env();
87 config.model = model.into();
88 config
89 }
90
91 #[must_use]
93 pub fn with_temperature(mut self, temp: f64) -> Self {
94 self.temperature = Some(temp);
95 self
96 }
97
98 #[must_use]
100 pub fn with_max_tokens(mut self, tokens: i64) -> Self {
101 self.max_tokens = Some(tokens);
102 self
103 }
104
105 pub fn is_mock(&self) -> bool {
107 self.api_key == "mock"
108 }
109
110 pub fn is_ollama(&self) -> bool {
112 self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
113 }
114}
115
116impl LlmClient {
117 pub fn new(config: LlmConfig) -> Self {
119 Self {
120 client: reqwest::Client::new(),
121 config,
122 }
123 }
124
125 pub fn from_env() -> Self {
127 Self::new(LlmConfig::from_env())
128 }
129
130 pub fn mock() -> Self {
132 Self::new(LlmConfig::mock())
133 }
134
135 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
137 if self.config.is_mock() {
138 return Ok(format!("[Mock LLM response for: {prompt}]"));
139 }
140
141 let request = ChatRequest::new(
142 &self.config.model,
143 vec![ChatMessage {
144 role: "user",
145 content: prompt,
146 }],
147 )
148 .with_config(&self.config);
149
150 self.send_request(&request).await
151 }
152
153 pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
155 where
156 T: DeserializeOwned,
157 {
158 let response = self.infer_string(prompt).await?;
159 parse_json_response(&response)
160 }
161
162 pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
167 where
168 T: DeserializeOwned,
169 {
170 if self.config.is_mock() {
171 return Err(SageError::Llm(
173 "Mock client cannot produce structured output".to_string(),
174 ));
175 }
176
177 let system_prompt = format!(
178 "You are a precise assistant that always responds with valid JSON.\n\
179 You must respond with a JSON object matching this exact schema:\n\n\
180 {schema}\n\n\
181 Respond with JSON only. No explanation, no markdown, no code blocks."
182 );
183
184 let mut last_error: Option<String> = None;
185
186 for attempt in 0..self.config.infer_retries {
187 let response = if attempt == 0 {
188 self.send_structured_request(&system_prompt, prompt, None)
189 .await?
190 } else {
191 let error_feedback = format!(
192 "Your previous response could not be parsed: {}\n\
193 Please try again, responding with valid JSON only.",
194 last_error.as_deref().unwrap_or("unknown error")
195 );
196 self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
197 .await?
198 };
199
200 match parse_json_response::<T>(&response) {
201 Ok(value) => return Ok(value),
202 Err(e) => {
203 last_error = Some(e.to_string());
204 }
206 }
207 }
208
209 Err(SageError::Llm(format!(
210 "Failed to parse structured response after {} attempts: {}",
211 self.config.infer_retries,
212 last_error.unwrap_or_else(|| "unknown error".to_string())
213 )))
214 }
215
216 async fn send_structured_request(
218 &self,
219 system_prompt: &str,
220 user_prompt: &str,
221 error_feedback: Option<&str>,
222 ) -> SageResult<String> {
223 let mut messages = vec![
224 ChatMessage {
225 role: "system",
226 content: system_prompt,
227 },
228 ChatMessage {
229 role: "user",
230 content: user_prompt,
231 },
232 ];
233
234 if let Some(feedback) = error_feedback {
235 messages.push(ChatMessage {
236 role: "user",
237 content: feedback,
238 });
239 }
240
241 let mut request = ChatRequest::new(&self.config.model, messages).with_config(&self.config);
242
243 if self.config.is_ollama() {
245 request = request.with_json_format();
246 }
247
248 self.send_request(&request).await
249 }
250
251 async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
253 let response = self
254 .client
255 .post(format!("{}/chat/completions", self.config.base_url))
256 .header("Authorization", format!("Bearer {}", self.config.api_key))
257 .header("Content-Type", "application/json")
258 .json(request)
259 .send()
260 .await?;
261
262 if !response.status().is_success() {
263 let status = response.status();
264 let body = response.text().await.unwrap_or_default();
265 return Err(SageError::Llm(format!("API error {status}: {body}")));
266 }
267
268 let chat_response: ChatResponse = response.json().await?;
269 let content = chat_response
270 .choices
271 .into_iter()
272 .next()
273 .map(|c| c.message.content)
274 .unwrap_or_default();
275
276 Ok(content)
277 }
278}
279
280fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
282 if let Ok(value) = serde_json::from_str(response) {
284 return Ok(value);
285 }
286
287 let cleaned = response
289 .trim()
290 .strip_prefix("```json")
291 .or_else(|| response.trim().strip_prefix("```"))
292 .unwrap_or(response.trim());
293
294 let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();
295
296 serde_json::from_str(cleaned).map_err(|e| {
297 SageError::Llm(format!(
298 "Failed to parse LLM response as {}: {e}\nResponse: {response}",
299 std::any::type_name::<T>()
300 ))
301 })
302}
303
/// Request body for `POST {base_url}/chat/completions`; borrows all strings
/// from the caller so building a request allocates only the message vector.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    /// Ollama-specific switch: serialized as `"format": "json"` to request
    /// strict-JSON output; omitted entirely for other backends.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
    /// Sampling temperature copied from `LlmConfig`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Completion token cap copied from `LlmConfig`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i64>,
}
315
/// One chat turn: a role ("system" or "user" in this client) plus its text.
#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
321
322impl<'a> ChatRequest<'a> {
323 fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
324 Self {
325 model,
326 messages,
327 format: None,
328 temperature: None,
329 max_tokens: None,
330 }
331 }
332
333 fn with_json_format(mut self) -> Self {
334 self.format = Some("json");
335 self
336 }
337
338 fn with_config(mut self, config: &LlmConfig) -> Self {
339 self.temperature = config.temperature;
340 self.max_tokens = config.max_tokens;
341 self
342 }
343}
344
/// Minimal deserialization of a chat-completions response; response fields
/// the client does not read are ignored by serde.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
349
/// One completion alternative; only its message payload is consumed.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
354
/// Assistant message text extracted from a [`Choice`].
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
359
#[cfg(test)]
mod tests {
    use super::*;

    // Mock clients must echo the prompt in a canned reply without any I/O.
    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let client = LlmClient::mock();
        let response = client.infer_string("test prompt").await.unwrap();
        assert!(response.contains("Mock LLM response"));
        assert!(response.contains("test prompt"));
    }

    // A ```json fenced reply parses after fence stripping.
    #[test]
    fn parse_json_strips_markdown_fences() {
        let response = "```json\n{\"value\": 42}\n```";
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["value"], 42);
    }

    // Bare JSON takes the fast path and parses unchanged.
    #[test]
    fn parse_json_handles_plain_json() {
        let response = r#"{"name": "test"}"#;
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["name"], "test");
    }

    // A language-less ``` fence is stripped the same way as ```json.
    #[test]
    fn parse_json_handles_generic_code_block() {
        let response = "```\n{\"x\": 1}\n```";
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["x"], 1);
    }

    // "localhost" in the base URL triggers the Ollama heuristic.
    #[test]
    fn ollama_detection_localhost() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "http://localhost:11434/v1".to_string(),
            model: "llama2".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(config.is_ollama());
    }

    // The numeric loopback address also counts as Ollama.
    #[test]
    fn ollama_detection_127() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "http://127.0.0.1:11434/v1".to_string(),
            model: "llama2".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(config.is_ollama());
    }

    // Remote endpoints must not be treated as Ollama.
    #[test]
    fn not_ollama_for_openai() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(!config.is_ollama());
    }

    // with_json_format serializes the Ollama-specific "format" field.
    #[test]
    fn chat_request_json_format() {
        let request = ChatRequest::new("model", vec![]).with_json_format();
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""format":"json""#));
    }

    // The "format" field is skipped entirely when unset.
    #[test]
    fn chat_request_no_format_by_default() {
        let request = ChatRequest::new("model", vec![]);
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("format"));
    }

    // Structured inference refuses to run against a mock client.
    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let client = LlmClient::mock();
        let result: Result<serde_json::Value, _> = client.infer_structured("test", "{}").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock client"));
    }
}
453}