Skip to main content

symbi_runtime/reasoning/providers/
slm.rs

1//! SLM inference provider
2//!
3//! Wraps the existing `SlmRunner` to implement `InferenceProvider`.
4//! Since SLMs don't natively support tool calling or structured output,
5//! this provider injects tool definitions and JSON schemas into the system
6//! prompt and parses structured JSON from the text output.
7
8use crate::models::runners::{ExecutionOptions, SlmRunner};
9use crate::reasoning::conversation::Conversation;
10use crate::reasoning::inference::*;
11use async_trait::async_trait;
12use std::sync::Arc;
13
14/// SLM inference provider wrapping a `SlmRunner`.
15pub struct SlmInferenceProvider {
16    runner: Arc<dyn SlmRunner>,
17    model_name: String,
18}
19
20impl SlmInferenceProvider {
21    /// Create a new SLM provider from an existing runner.
22    pub fn new(runner: Arc<dyn SlmRunner>, model_name: impl Into<String>) -> Self {
23        Self {
24            runner,
25            model_name: model_name.into(),
26        }
27    }
28
29    /// Build a single prompt string from a conversation, injecting tool
30    /// definitions and response format instructions into the system prompt.
31    fn build_prompt(conversation: &Conversation, options: &InferenceOptions) -> String {
32        let mut parts = Vec::new();
33
34        // Start with system message, augmented with tool/format instructions
35        if let Some(sys) = conversation.system_message() {
36            parts.push(format!("### System\n{}", sys.content));
37        }
38
39        // Inject tool definitions into the prompt
40        if !options.tool_definitions.is_empty() {
41            let mut tool_section = String::from("\n### Available Tools\nYou have access to the following tools. To call a tool, respond with a JSON object in this exact format:\n```json\n{\"tool_calls\": [{\"name\": \"<tool_name>\", \"arguments\": {<args>}}]}\n```\n\nTools:\n");
42            for td in &options.tool_definitions {
43                tool_section.push_str(&format!(
44                    "- **{}**: {}\n  Parameters: {}\n",
45                    td.name,
46                    td.description,
47                    serde_json::to_string_pretty(&td.parameters).unwrap_or_default()
48                ));
49            }
50            tool_section
51                .push_str("\nIf you don't need to call any tools, respond with plain text.\n");
52            parts.push(tool_section);
53        }
54
55        // Inject response format instructions
56        match &options.response_format {
57            ResponseFormat::Text => {}
58            ResponseFormat::JsonObject => {
59                parts.push(
60                    "\n### Response Format\nYou MUST respond with a valid JSON object. Do not include any text outside the JSON.".into(),
61                );
62            }
63            ResponseFormat::JsonSchema { schema, .. } => {
64                parts.push(format!(
65                    "\n### Response Format\nYou MUST respond with a valid JSON object conforming to this schema:\n```json\n{}\n```\nDo not include any text outside the JSON.",
66                    serde_json::to_string_pretty(schema).unwrap_or_default()
67                ));
68            }
69        }
70
71        // Add conversation history
72        for msg in conversation.messages() {
73            match msg.role {
74                crate::reasoning::conversation::MessageRole::System => continue, // Already handled
75                crate::reasoning::conversation::MessageRole::User => {
76                    parts.push(format!("\n### User\n{}", msg.content));
77                }
78                crate::reasoning::conversation::MessageRole::Assistant => {
79                    if !msg.tool_calls.is_empty() {
80                        let tc_json: Vec<serde_json::Value> = msg
81                            .tool_calls
82                            .iter()
83                            .map(|tc| {
84                                serde_json::json!({
85                                    "name": tc.name,
86                                    "arguments": serde_json::from_str::<serde_json::Value>(&tc.arguments).unwrap_or(serde_json::json!({}))
87                                })
88                            })
89                            .collect();
90                        parts.push(format!(
91                            "\n### Assistant\n```json\n{{\"tool_calls\": {}}}\n```",
92                            serde_json::to_string(&tc_json).unwrap_or_default()
93                        ));
94                    } else {
95                        parts.push(format!("\n### Assistant\n{}", msg.content));
96                    }
97                }
98                crate::reasoning::conversation::MessageRole::Tool => {
99                    let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
100                    parts.push(format!(
101                        "\n### Tool Result ({})\n{}",
102                        tool_name, msg.content
103                    ));
104                }
105            }
106        }
107
108        parts.push("\n### Assistant\n".into());
109        parts.join("\n")
110    }
111
112    /// Attempt to extract tool calls from the SLM's text response.
113    ///
114    /// Looks for JSON blocks containing a `tool_calls` array, either bare
115    /// or wrapped in markdown code fences.
116    fn extract_tool_calls(text: &str) -> Vec<ToolCallRequest> {
117        // Try to find JSON with tool_calls in the response
118        let json_text = strip_markdown_fences(text);
119
120        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_text) {
121            if let Some(calls) = parsed.get("tool_calls").and_then(|c| c.as_array()) {
122                return calls
123                    .iter()
124                    .enumerate()
125                    .filter_map(|(i, call)| {
126                        let name = call.get("name")?.as_str()?.to_string();
127                        let arguments = call
128                            .get("arguments")
129                            .map(|a| serde_json::to_string(a).unwrap_or_default())
130                            .unwrap_or_else(|| "{}".into());
131                        Some(ToolCallRequest {
132                            id: format!("slm_call_{}", i),
133                            name,
134                            arguments,
135                        })
136                    })
137                    .collect();
138            }
139        }
140
141        Vec::new()
142    }
143}
144
/// Strip markdown code fences from a string, returning the inner content.
///
/// Applies only when the trimmed input starts with three backticks. In that case:
/// - the rest of the opening line (e.g. a `json` language tag) is dropped;
/// - everything from the last closing fence onward is discarded, so commentary
///   after the block does not leak into the result;
/// - if there is no closing fence, everything after the opening line is kept.
///
/// The result is trimmed. Input that does not start with a fence is returned
/// trimmed but otherwise unchanged.
pub fn strip_markdown_fences(text: &str) -> String {
    let trimmed = text.trim();

    let after_open = match trimmed.strip_prefix("```") {
        Some(rest) => rest,
        None => return trimmed.to_string(),
    };

    // The remainder of the opening line is the optional language tag.
    let body = after_open
        .split_once('\n')
        .map_or(after_open, |(_tag, body)| body);

    // Cut at the last closing fence wherever it is, not only at the very end.
    let inner = body.rfind("```").map_or(body, |close| &body[..close]);
    inner.trim().to_string()
}
165
166#[async_trait]
167impl InferenceProvider for SlmInferenceProvider {
168    async fn complete(
169        &self,
170        conversation: &Conversation,
171        options: &InferenceOptions,
172    ) -> Result<InferenceResponse, InferenceError> {
173        let prompt = Self::build_prompt(conversation, options);
174
175        let exec_options = ExecutionOptions {
176            timeout: Some(std::time::Duration::from_secs(60)),
177            temperature: Some(options.temperature),
178            max_tokens: Some(options.max_tokens),
179            custom_parameters: Default::default(),
180        };
181
182        let result = self
183            .runner
184            .execute(&prompt, Some(exec_options))
185            .await
186            .map_err(|e| InferenceError::Provider(format!("SLM execution failed: {}", e)))?;
187
188        let response_text = result.response.clone();
189        let tool_calls = Self::extract_tool_calls(&response_text);
190
191        let finish_reason = if !tool_calls.is_empty() {
192            FinishReason::ToolCalls
193        } else {
194            FinishReason::Stop
195        };
196
197        let content = if !tool_calls.is_empty() {
198            // If we extracted tool calls, the text content is whatever remains
199            // outside the JSON block (may be empty)
200            String::new()
201        } else {
202            response_text
203        };
204
205        let usage = Usage {
206            prompt_tokens: result.metadata.input_tokens.unwrap_or(0),
207            completion_tokens: result.metadata.output_tokens.unwrap_or(0),
208            total_tokens: result
209                .metadata
210                .input_tokens
211                .unwrap_or(0)
212                .saturating_add(result.metadata.output_tokens.unwrap_or(0)),
213        };
214
215        Ok(InferenceResponse {
216            content,
217            tool_calls,
218            finish_reason,
219            usage,
220            model: self.model_name.clone(),
221        })
222    }
223
224    fn provider_name(&self) -> &str {
225        "slm"
226    }
227
228    fn default_model(&self) -> &str {
229        &self.model_name
230    }
231
232    fn supports_native_tools(&self) -> bool {
233        false
234    }
235
236    fn supports_structured_output(&self) -> bool {
237        false
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::ConversationMessage;

    /// Parse `raw` as JSON, panicking (failing the test) if it is invalid.
    fn parse_json(raw: &str) -> serde_json::Value {
        serde_json::from_str(raw).unwrap()
    }

    /// Assert that every snippet in `needles` occurs somewhere in `prompt`.
    fn assert_contains_all(prompt: &str, needles: &[&str]) {
        for needle in needles {
            assert!(prompt.contains(needle), "prompt is missing {:?}", needle);
        }
    }

    #[test]
    fn test_strip_markdown_fences_json() {
        let fenced =
            "```json\n{\"tool_calls\": [{\"name\": \"search\", \"arguments\": {\"q\": \"test\"}}]}\n```";
        let inner = strip_markdown_fences(fenced);

        assert!(inner.starts_with('{'));
        assert!(inner.ends_with('}'));
        assert!(parse_json(&inner).get("tool_calls").is_some());
    }

    #[test]
    fn test_strip_markdown_fences_plain() {
        let inner = strip_markdown_fences("```\n{\"key\": \"value\"}\n```");
        assert_eq!(parse_json(&inner)["key"], "value");
    }

    #[test]
    fn test_strip_markdown_fences_no_fences() {
        let bare = "{\"key\": \"value\"}";
        assert_eq!(strip_markdown_fences(bare), bare);
    }

    #[test]
    fn test_extract_tool_calls_valid() {
        let reply =
            "```json\n{\"tool_calls\": [{\"name\": \"web_search\", \"arguments\": {\"query\": \"rust\"}}]}\n```";
        let extracted = SlmInferenceProvider::extract_tool_calls(reply);

        assert_eq!(extracted.len(), 1);
        let only = &extracted[0];
        assert_eq!(only.name, "web_search");
        assert_eq!(only.id, "slm_call_0");
    }

    #[test]
    fn test_extract_tool_calls_no_tools() {
        let reply = "I don't need any tools for this. The answer is 42.";
        assert!(SlmInferenceProvider::extract_tool_calls(reply).is_empty());
    }

    #[test]
    fn test_extract_tool_calls_multiple() {
        let reply = serde_json::json!({
            "tool_calls": [
                {"name": "search", "arguments": {"q": "a"}},
                {"name": "read", "arguments": {"path": "/tmp/x"}}
            ]
        })
        .to_string();

        let names: Vec<String> = SlmInferenceProvider::extract_tool_calls(&reply)
            .into_iter()
            .map(|call| call.name)
            .collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "search");
        assert_eq!(names[1], "read");
    }

    #[test]
    fn test_build_prompt_basic() {
        let mut conversation = Conversation::with_system("You are helpful.");
        conversation.push(ConversationMessage::user("What is 2+2?"));

        let rendered =
            SlmInferenceProvider::build_prompt(&conversation, &InferenceOptions::default());

        assert_contains_all(
            &rendered,
            &[
                "### System",
                "You are helpful.",
                "### User",
                "What is 2+2?",
                "### Assistant",
            ],
        );
    }

    #[test]
    fn test_build_prompt_with_tools() {
        let search_tool = ToolDefinition {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({"type": "object", "properties": {"q": {"type": "string"}}}),
        };
        let options = InferenceOptions {
            tool_definitions: vec![search_tool],
            ..Default::default()
        };

        let rendered =
            SlmInferenceProvider::build_prompt(&Conversation::with_system("Agent"), &options);

        assert_contains_all(
            &rendered,
            &["### Available Tools", "search", "Search the web", "tool_calls"],
        );
    }

    #[test]
    fn test_build_prompt_with_json_schema() {
        let options = InferenceOptions {
            response_format: ResponseFormat::JsonSchema {
                schema: serde_json::json!({"type": "object", "properties": {"answer": {"type": "string"}}}),
                name: Some("Answer".into()),
            },
            ..Default::default()
        };

        let rendered =
            SlmInferenceProvider::build_prompt(&Conversation::with_system("Agent"), &options);

        assert_contains_all(
            &rendered,
            &["### Response Format", "JSON object conforming to this schema"],
        );
    }
}