// sage_runtime/llm.rs
//! LLM client for inference calls.
2
use crate::error::{SageError, SageResult};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
/// Default number of retries for structured inference.
/// Overridable at runtime via the `SAGE_INFER_RETRIES` environment variable.
const DEFAULT_INFER_RETRIES: usize = 3;
8
/// Client for making LLM inference calls.
#[derive(Clone)]
pub struct LlmClient {
    // HTTP client used for all chat-completion requests.
    client: reqwest::Client,
    // API endpoint, credentials, model, and retry settings.
    config: LlmConfig,
}
15
16/// Configuration for the LLM client.
17#[derive(Clone)]
18pub struct LlmConfig {
19    /// API key for authentication.
20    pub api_key: String,
21    /// Base URL for the API.
22    pub base_url: String,
23    /// Model to use.
24    pub model: String,
25    /// Max retries for structured inference.
26    pub infer_retries: usize,
27}
28
29impl LlmConfig {
30    /// Create a config from environment variables.
31    pub fn from_env() -> Self {
32        Self {
33            api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
34            base_url: std::env::var("SAGE_LLM_URL")
35                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
36            model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
37            infer_retries: std::env::var("SAGE_INFER_RETRIES")
38                .ok()
39                .and_then(|s| s.parse().ok())
40                .unwrap_or(DEFAULT_INFER_RETRIES),
41        }
42    }
43
44    /// Create a mock config for testing.
45    pub fn mock() -> Self {
46        Self {
47            api_key: "mock".to_string(),
48            base_url: "mock".to_string(),
49            model: "mock".to_string(),
50            infer_retries: DEFAULT_INFER_RETRIES,
51        }
52    }
53
54    /// Check if this is a mock configuration.
55    pub fn is_mock(&self) -> bool {
56        self.api_key == "mock"
57    }
58
59    /// Check if the base URL points to a local Ollama instance.
60    pub fn is_ollama(&self) -> bool {
61        self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
62    }
63}
64
65impl LlmClient {
66    /// Create a new LLM client with the given configuration.
67    pub fn new(config: LlmConfig) -> Self {
68        Self {
69            client: reqwest::Client::new(),
70            config,
71        }
72    }
73
74    /// Create a client from environment variables.
75    pub fn from_env() -> Self {
76        Self::new(LlmConfig::from_env())
77    }
78
79    /// Create a mock client for testing.
80    pub fn mock() -> Self {
81        Self::new(LlmConfig::mock())
82    }
83
84    /// Call the LLM with a prompt and return the raw string response.
85    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
86        if self.config.is_mock() {
87            return Ok(format!("[Mock LLM response for: {prompt}]"));
88        }
89
90        let request = ChatRequest::new(
91            &self.config.model,
92            vec![ChatMessage {
93                role: "user",
94                content: prompt,
95            }],
96        );
97
98        self.send_request(&request).await
99    }
100
101    /// Call the LLM with a prompt and parse the response as the given type.
102    pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
103    where
104        T: DeserializeOwned,
105    {
106        let response = self.infer_string(prompt).await?;
107        parse_json_response(&response)
108    }
109
110    /// Call the LLM with schema-injected prompt engineering for structured output.
111    ///
112    /// The schema is injected as a system message, and the runtime retries up to
113    /// `SAGE_INFER_RETRIES` times (default 3) on parse failure.
114    pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
115    where
116        T: DeserializeOwned,
117    {
118        if self.config.is_mock() {
119            // For mock mode, return an error since we can't produce valid structured output
120            return Err(SageError::Llm(
121                "Mock client cannot produce structured output".to_string(),
122            ));
123        }
124
125        let system_prompt = format!(
126            "You are a precise assistant that always responds with valid JSON.\n\
127             You must respond with a JSON object matching this exact schema:\n\n\
128             {schema}\n\n\
129             Respond with JSON only. No explanation, no markdown, no code blocks."
130        );
131
132        let mut last_error: Option<String> = None;
133
134        for attempt in 0..self.config.infer_retries {
135            let response = if attempt == 0 {
136                self.send_structured_request(&system_prompt, prompt, None)
137                    .await?
138            } else {
139                let error_feedback = format!(
140                    "Your previous response could not be parsed: {}\n\
141                     Please try again, responding with valid JSON only.",
142                    last_error.as_deref().unwrap_or("unknown error")
143                );
144                self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
145                    .await?
146            };
147
148            match parse_json_response::<T>(&response) {
149                Ok(value) => return Ok(value),
150                Err(e) => {
151                    last_error = Some(e.to_string());
152                    // Continue to next retry
153                }
154            }
155        }
156
157        Err(SageError::Llm(format!(
158            "Failed to parse structured response after {} attempts: {}",
159            self.config.infer_retries,
160            last_error.unwrap_or_else(|| "unknown error".to_string())
161        )))
162    }
163
164    /// Send a structured inference request with optional error feedback.
165    async fn send_structured_request(
166        &self,
167        system_prompt: &str,
168        user_prompt: &str,
169        error_feedback: Option<&str>,
170    ) -> SageResult<String> {
171        let mut messages = vec![
172            ChatMessage {
173                role: "system",
174                content: system_prompt,
175            },
176            ChatMessage {
177                role: "user",
178                content: user_prompt,
179            },
180        ];
181
182        if let Some(feedback) = error_feedback {
183            messages.push(ChatMessage {
184                role: "user",
185                content: feedback,
186            });
187        }
188
189        let mut request = ChatRequest::new(&self.config.model, messages);
190
191        // Add format: json hint for Ollama
192        if self.config.is_ollama() {
193            request = request.with_json_format();
194        }
195
196        self.send_request(&request).await
197    }
198
199    /// Send a chat request and return the response content.
200    async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
201        let response = self
202            .client
203            .post(format!("{}/chat/completions", self.config.base_url))
204            .header("Authorization", format!("Bearer {}", self.config.api_key))
205            .header("Content-Type", "application/json")
206            .json(request)
207            .send()
208            .await?;
209
210        if !response.status().is_success() {
211            let status = response.status();
212            let body = response.text().await.unwrap_or_default();
213            return Err(SageError::Llm(format!("API error {status}: {body}")));
214        }
215
216        let chat_response: ChatResponse = response.json().await?;
217        let content = chat_response
218            .choices
219            .into_iter()
220            .next()
221            .map(|c| c.message.content)
222            .unwrap_or_default();
223
224        Ok(content)
225    }
226}
227
228/// Strip markdown code fences from a response and parse as JSON.
229fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
230    // Try to parse as-is first
231    if let Ok(value) = serde_json::from_str(response) {
232        return Ok(value);
233    }
234
235    // Strip markdown code blocks if present
236    let cleaned = response
237        .trim()
238        .strip_prefix("```json")
239        .or_else(|| response.trim().strip_prefix("```"))
240        .unwrap_or(response.trim());
241
242    let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();
243
244    serde_json::from_str(cleaned).map_err(|e| {
245        SageError::Llm(format!(
246            "Failed to parse LLM response as {}: {e}\nResponse: {response}",
247            std::any::type_name::<T>()
248        ))
249    })
250}
251
/// Outgoing request body for the `/chat/completions` endpoint.
/// Borrows all text to avoid copying prompts.
#[derive(Serialize)]
struct ChatRequest<'a> {
    // Model identifier, e.g. "gpt-4o-mini".
    model: &'a str,
    // Conversation messages, in order.
    messages: Vec<ChatMessage<'a>>,
    // Ollama-specific output format hint ("json"); omitted from the
    // serialized body when `None` so other backends never see it.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
}
259
/// A single chat message in a request; borrows to avoid copying prompt text.
#[derive(Serialize)]
struct ChatMessage<'a> {
    // "system" or "user" at current call sites.
    role: &'a str,
    // Message text.
    content: &'a str,
}
265
266impl<'a> ChatRequest<'a> {
267    fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
268        Self {
269            model,
270            messages,
271            format: None,
272        }
273    }
274
275    fn with_json_format(mut self) -> Self {
276        self.format = Some("json");
277        self
278    }
279}
280
/// Top-level response body from the chat completions API.
#[derive(Deserialize)]
struct ChatResponse {
    // Candidate completions; only the first is consumed.
    choices: Vec<Choice>,
}
285
/// One completion candidate in a chat response.
#[derive(Deserialize)]
struct Choice {
    // The assistant message produced for this candidate.
    message: ResponseMessage,
}
290
/// Message payload inside a completion choice.
#[derive(Deserialize)]
struct ResponseMessage {
    // Raw text content of the model's reply.
    content: String,
}
295
#[cfg(test)]
mod tests {
    use super::*;

    // Shared builder for a non-mock config pointing at `base_url`.
    fn config_with(base_url: &str, model: &str) -> LlmConfig {
        LlmConfig {
            api_key: "test".to_string(),
            base_url: base_url.to_string(),
            model: model.to_string(),
            infer_retries: 3,
        }
    }

    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let reply = LlmClient::mock().infer_string("test prompt").await.unwrap();
        assert!(reply.contains("Mock LLM response"));
        assert!(reply.contains("test prompt"));
    }

    #[test]
    fn parse_json_strips_markdown_fences() {
        let parsed: serde_json::Value =
            parse_json_response("```json\n{\"value\": 42}\n```").unwrap();
        assert_eq!(parsed["value"], 42);
    }

    #[test]
    fn parse_json_handles_plain_json() {
        let parsed: serde_json::Value = parse_json_response(r#"{"name": "test"}"#).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    #[test]
    fn parse_json_handles_generic_code_block() {
        let parsed: serde_json::Value = parse_json_response("```\n{\"x\": 1}\n```").unwrap();
        assert_eq!(parsed["x"], 1);
    }

    #[test]
    fn ollama_detection_localhost() {
        assert!(config_with("http://localhost:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn ollama_detection_127() {
        assert!(config_with("http://127.0.0.1:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn not_ollama_for_openai() {
        assert!(!config_with("https://api.openai.com/v1", "gpt-4").is_ollama());
    }

    #[test]
    fn chat_request_json_format() {
        let body = ChatRequest::new("model", vec![]).with_json_format();
        let serialized = serde_json::to_string(&body).unwrap();
        assert!(serialized.contains(r#""format":"json""#));
    }

    #[test]
    fn chat_request_no_format_by_default() {
        let serialized = serde_json::to_string(&ChatRequest::new("model", vec![])).unwrap();
        assert!(!serialized.contains("format"));
    }

    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let outcome: Result<serde_json::Value, _> =
            LlmClient::mock().infer_structured("test", "{}").await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("Mock client"));
    }
}