// symbi_runtime/reasoning/providers/slm.rs

use crate::models::runners::{ExecutionOptions, SlmRunner};
use crate::reasoning::conversation::Conversation;
use crate::reasoning::inference::*;
use async_trait::async_trait;
use std::sync::Arc;

/// [`InferenceProvider`] backed by a local small language model (SLM).
///
/// Tool calling and structured output are emulated through prompt text
/// (see `build_prompt`), since the underlying runner only consumes and
/// produces raw strings.
pub struct SlmInferenceProvider {
    // Executes rendered prompts against the local model.
    runner: Arc<dyn SlmRunner>,
    // Echoed back in `InferenceResponse::model` and `default_model()`.
    model_name: String,
}

20impl SlmInferenceProvider {
21 pub fn new(runner: Arc<dyn SlmRunner>, model_name: impl Into<String>) -> Self {
23 Self {
24 runner,
25 model_name: model_name.into(),
26 }
27 }
28
29 fn build_prompt(conversation: &Conversation, options: &InferenceOptions) -> String {
32 let mut parts = Vec::new();
33
34 if let Some(sys) = conversation.system_message() {
36 parts.push(format!("### System\n{}", sys.content));
37 }
38
39 if !options.tool_definitions.is_empty() {
41 let mut tool_section = String::from("\n### Available Tools\nYou have access to the following tools. To call a tool, respond with a JSON object in this exact format:\n```json\n{\"tool_calls\": [{\"name\": \"<tool_name>\", \"arguments\": {<args>}}]}\n```\n\nTools:\n");
42 for td in &options.tool_definitions {
43 tool_section.push_str(&format!(
44 "- **{}**: {}\n Parameters: {}\n",
45 td.name,
46 td.description,
47 serde_json::to_string_pretty(&td.parameters).unwrap_or_default()
48 ));
49 }
50 tool_section
51 .push_str("\nIf you don't need to call any tools, respond with plain text.\n");
52 parts.push(tool_section);
53 }
54
55 match &options.response_format {
57 ResponseFormat::Text => {}
58 ResponseFormat::JsonObject => {
59 parts.push(
60 "\n### Response Format\nYou MUST respond with a valid JSON object. Do not include any text outside the JSON.".into(),
61 );
62 }
63 ResponseFormat::JsonSchema { schema, .. } => {
64 parts.push(format!(
65 "\n### Response Format\nYou MUST respond with a valid JSON object conforming to this schema:\n```json\n{}\n```\nDo not include any text outside the JSON.",
66 serde_json::to_string_pretty(schema).unwrap_or_default()
67 ));
68 }
69 }
70
71 for msg in conversation.messages() {
73 match msg.role {
74 crate::reasoning::conversation::MessageRole::System => continue, crate::reasoning::conversation::MessageRole::User => {
76 parts.push(format!("\n### User\n{}", msg.content));
77 }
78 crate::reasoning::conversation::MessageRole::Assistant => {
79 if !msg.tool_calls.is_empty() {
80 let tc_json: Vec<serde_json::Value> = msg
81 .tool_calls
82 .iter()
83 .map(|tc| {
84 serde_json::json!({
85 "name": tc.name,
86 "arguments": serde_json::from_str::<serde_json::Value>(&tc.arguments).unwrap_or(serde_json::json!({}))
87 })
88 })
89 .collect();
90 parts.push(format!(
91 "\n### Assistant\n```json\n{{\"tool_calls\": {}}}\n```",
92 serde_json::to_string(&tc_json).unwrap_or_default()
93 ));
94 } else {
95 parts.push(format!("\n### Assistant\n{}", msg.content));
96 }
97 }
98 crate::reasoning::conversation::MessageRole::Tool => {
99 let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
100 parts.push(format!(
101 "\n### Tool Result ({})\n{}",
102 tool_name, msg.content
103 ));
104 }
105 }
106 }
107
108 parts.push("\n### Assistant\n".into());
109 parts.join("\n")
110 }
111
112 fn extract_tool_calls(text: &str) -> Vec<ToolCallRequest> {
117 let json_text = strip_markdown_fences(text);
119
120 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_text) {
121 if let Some(calls) = parsed.get("tool_calls").and_then(|c| c.as_array()) {
122 return calls
123 .iter()
124 .enumerate()
125 .filter_map(|(i, call)| {
126 let name = call.get("name")?.as_str()?.to_string();
127 let arguments = call
128 .get("arguments")
129 .map(|a| serde_json::to_string(a).unwrap_or_default())
130 .unwrap_or_else(|| "{}".into());
131 Some(ToolCallRequest {
132 id: format!("slm_call_{}", i),
133 name,
134 arguments,
135 })
136 })
137 .collect();
138 }
139 }
140
141 Vec::new()
142 }
143}
144
/// Removes a surrounding markdown code fence (with optional info string,
/// e.g. ```` ```json ````) from `text`, returning the trimmed inner body.
///
/// Anything after the *last* closing ``` is discarded, so chatty model
/// output like "```json\n{...}\n```\nDone." still yields parseable JSON
/// (the previous version only handled a fence at the very end of the text).
/// Text without a leading fence is returned trimmed and otherwise unchanged.
pub fn strip_markdown_fences(text: &str) -> String {
    let trimmed = text.trim();

    let Some(rest) = trimmed.strip_prefix("```") else {
        return trimmed.to_string();
    };

    // Drop the info string ("json", "text", ...) up to the first newline;
    // a fence with no newline keeps the whole remainder.
    let content = match rest.find('\n') {
        Some(idx) => &rest[idx + 1..],
        None => rest,
    };

    // Cut at the last closing fence; if the model never closed the fence,
    // fall back to the whole remaining body.
    let inner = match content.rfind("```") {
        Some(idx) => &content[..idx],
        None => content,
    };

    inner.trim().to_string()
}

166#[async_trait]
167impl InferenceProvider for SlmInferenceProvider {
168 async fn complete(
169 &self,
170 conversation: &Conversation,
171 options: &InferenceOptions,
172 ) -> Result<InferenceResponse, InferenceError> {
173 let prompt = Self::build_prompt(conversation, options);
174
175 let exec_options = ExecutionOptions {
176 timeout: Some(std::time::Duration::from_secs(60)),
177 temperature: Some(options.temperature),
178 max_tokens: Some(options.max_tokens),
179 custom_parameters: Default::default(),
180 };
181
182 let result = self
183 .runner
184 .execute(&prompt, Some(exec_options))
185 .await
186 .map_err(|e| InferenceError::Provider(format!("SLM execution failed: {}", e)))?;
187
188 let response_text = result.response.clone();
189 let tool_calls = Self::extract_tool_calls(&response_text);
190
191 let finish_reason = if !tool_calls.is_empty() {
192 FinishReason::ToolCalls
193 } else {
194 FinishReason::Stop
195 };
196
197 let content = if !tool_calls.is_empty() {
198 String::new()
201 } else {
202 response_text
203 };
204
205 let usage = Usage {
206 prompt_tokens: result.metadata.input_tokens.unwrap_or(0),
207 completion_tokens: result.metadata.output_tokens.unwrap_or(0),
208 total_tokens: result
209 .metadata
210 .input_tokens
211 .unwrap_or(0)
212 .saturating_add(result.metadata.output_tokens.unwrap_or(0)),
213 };
214
215 Ok(InferenceResponse {
216 content,
217 tool_calls,
218 finish_reason,
219 usage,
220 model: self.model_name.clone(),
221 })
222 }
223
224 fn provider_name(&self) -> &str {
225 "slm"
226 }
227
228 fn default_model(&self) -> &str {
229 &self.model_name
230 }
231
232 fn supports_native_tools(&self) -> bool {
233 false
234 }
235
236 fn supports_structured_output(&self) -> bool {
237 false
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::ConversationMessage;

    #[test]
    fn test_strip_markdown_fences_json() {
        let fenced = r#"```json
{"tool_calls": [{"name": "search", "arguments": {"q": "test"}}]}
```"#;
        let stripped = strip_markdown_fences(fenced);
        assert!(stripped.starts_with('{'));
        assert!(stripped.ends_with('}'));
        let value: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert!(value.get("tool_calls").is_some());
    }

    #[test]
    fn test_strip_markdown_fences_plain() {
        let fenced = "```\n{\"key\": \"value\"}\n```";
        let value: serde_json::Value =
            serde_json::from_str(&strip_markdown_fences(fenced)).unwrap();
        assert_eq!(value["key"], "value");
    }

    #[test]
    fn test_strip_markdown_fences_no_fences() {
        let bare = "{\"key\": \"value\"}";
        assert_eq!(strip_markdown_fences(bare), bare);
    }

    #[test]
    fn test_extract_tool_calls_valid() {
        let reply = "```json\n{\"tool_calls\": [{\"name\": \"web_search\", \"arguments\": {\"query\": \"rust\"}}]}\n```";
        let calls = SlmInferenceProvider::extract_tool_calls(reply);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "web_search");
        assert_eq!(calls[0].id, "slm_call_0");
    }

    #[test]
    fn test_extract_tool_calls_no_tools() {
        let reply = "I don't need any tools for this. The answer is 42.";
        assert!(SlmInferenceProvider::extract_tool_calls(reply).is_empty());
    }

    #[test]
    fn test_extract_tool_calls_multiple() {
        let reply = r#"{"tool_calls": [
    {"name": "search", "arguments": {"q": "a"}},
    {"name": "read", "arguments": {"path": "/tmp/x"}}
]}"#;
        let calls = SlmInferenceProvider::extract_tool_calls(reply);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[1].name, "read");
    }

    #[test]
    fn test_build_prompt_basic() {
        let mut conv = Conversation::with_system("You are helpful.");
        conv.push(ConversationMessage::user("What is 2+2?"));

        let prompt = SlmInferenceProvider::build_prompt(&conv, &InferenceOptions::default());

        for needle in [
            "### System",
            "You are helpful.",
            "### User",
            "What is 2+2?",
            "### Assistant",
        ] {
            assert!(prompt.contains(needle), "missing section: {}", needle);
        }
    }

    #[test]
    fn test_build_prompt_with_tools() {
        let conv = Conversation::with_system("Agent");
        let opts = InferenceOptions {
            tool_definitions: vec![ToolDefinition {
                name: "search".into(),
                description: "Search the web".into(),
                parameters: serde_json::json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            }],
            ..Default::default()
        };

        let prompt = SlmInferenceProvider::build_prompt(&conv, &opts);

        for needle in ["### Available Tools", "search", "Search the web", "tool_calls"] {
            assert!(prompt.contains(needle), "missing: {}", needle);
        }
    }

    #[test]
    fn test_build_prompt_with_json_schema() {
        let conv = Conversation::with_system("Agent");
        let opts = InferenceOptions {
            response_format: ResponseFormat::JsonSchema {
                schema: serde_json::json!({"type": "object", "properties": {"answer": {"type": "string"}}}),
                name: Some("Answer".into()),
            },
            ..Default::default()
        };

        let prompt = SlmInferenceProvider::build_prompt(&conv, &opts);

        assert!(prompt.contains("### Response Format"));
        assert!(prompt.contains("JSON object conforming to this schema"));
    }
}