Skip to main content

synwire_core/output_parsers/tools.rs

1//! Tools output parser that extracts tool calls from messages.
2
3use crate::error::{ParseError, SynwireError};
4use crate::messages::{Message, ToolCall};
5use crate::output_parsers::OutputParser;
6
/// Parser that extracts tool calls from an AI message.
///
/// Can operate in two modes:
/// - **Message-based**: [`ToolsOutputParser::parse_message`] returns the `ToolCall`
///   values carried by a `Message::AI` variant, or an empty vector for any other variant.
/// - **Text-based**: [`OutputParser::parse`] deserializes a JSON array of tool calls
///   from raw text.
///
/// The parser is a field-less unit struct, so it is freely copyable and
/// `ToolsOutputParser::default()` is equivalent to `ToolsOutputParser`.
///
/// # Examples
///
/// ```
/// use synwire_core::output_parsers::ToolsOutputParser;
/// use synwire_core::messages::Message;
///
/// let parser = ToolsOutputParser;
/// let msg = Message::ai("No tools needed");
/// let calls = parser.parse_message(&msg).unwrap();
/// assert!(calls.is_empty());
/// ```
///
/// Parsing tool calls from a JSON array:
///
/// ```
/// use synwire_core::output_parsers::{OutputParser, ToolsOutputParser};
///
/// let parser = ToolsOutputParser::default();
/// let calls = parser
///     .parse(r#"[{"id": "tc_1", "name": "search", "arguments": {"query": "rust"}}]"#)
///     .unwrap();
/// assert_eq!(calls.len(), 1);
/// assert_eq!(calls[0].name, "search");
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ToolsOutputParser;
25
26impl ToolsOutputParser {
27    /// Extract tool calls from a message.
28    ///
29    /// Returns the tool calls if the message is an AI message, or an empty
30    /// vector for any other message type.
31    ///
32    /// # Errors
33    ///
34    /// This method currently does not produce errors, but returns `Result`
35    /// for forward compatibility.
36    pub fn parse_message(&self, message: &Message) -> Result<Vec<ToolCall>, SynwireError> {
37        match message {
38            Message::AI { tool_calls, .. } => Ok(tool_calls.clone()),
39            _ => Ok(Vec::new()),
40        }
41    }
42}
43
44impl OutputParser for ToolsOutputParser {
45    type Output = Vec<ToolCall>;
46
47    fn parse(&self, text: &str) -> Result<Vec<ToolCall>, SynwireError> {
48        serde_json::from_str(text).map_err(|e| {
49            SynwireError::from(ParseError::ParseFailed {
50                message: format!("Failed to parse tool calls: {e}"),
51            })
52        })
53    }
54}
55
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is how a failed expectation surfaces.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Builds an AI message carrying `tool_calls` and no other metadata.
    fn ai_with_calls(tool_calls: Vec<ToolCall>) -> Message {
        Message::AI {
            id: None,
            name: None,
            content: crate::messages::MessageContent::Text("Calling tool".into()),
            tool_calls,
            invalid_tool_calls: Vec::new(),
            usage: None,
            response_metadata: None,
            additional_kwargs: HashMap::new(),
        }
    }

    /// Builds a tool call with the given id and name and no arguments.
    fn bare_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: HashMap::new(),
        }
    }

    #[test]
    fn test_tools_parser_extracts_calls() {
        let parser = ToolsOutputParser;
        let mut arguments = HashMap::new();
        let _ = arguments.insert("query".into(), serde_json::Value::String("rust".into()));
        let msg = ai_with_calls(vec![ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments,
        }]);
        let calls = parser.parse_message(&msg).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }

    #[test]
    fn test_tools_parser_preserves_call_order() {
        let parser = ToolsOutputParser;
        let msg = ai_with_calls(vec![bare_call("tc_1", "search"), bare_call("tc_2", "fetch")]);
        let calls = parser.parse_message(&msg).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id, "tc_1");
        assert_eq!(calls[1].id, "tc_2");
    }

    #[test]
    fn test_tools_parser_ai_message_without_calls() {
        let parser = ToolsOutputParser;
        let msg = Message::ai("No tools needed");
        let calls = parser.parse_message(&msg).unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn test_tools_parser_non_ai_message() {
        let parser = ToolsOutputParser;
        let msg = Message::human("Hello");
        let calls = parser.parse_message(&msg).unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn test_tools_parser_from_text() {
        let parser = ToolsOutputParser;
        let json = r#"[{"id": "tc_1", "name": "search", "arguments": {"query": "test"}}]"#;
        let calls = parser.parse(json).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "tc_1");
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].arguments.len(), 1);
    }

    #[test]
    fn test_tools_parser_from_text_multiple_calls() {
        let parser = ToolsOutputParser;
        let json = r#"[
            {"id": "tc_1", "name": "search", "arguments": {}},
            {"id": "tc_2", "name": "fetch", "arguments": {}}
        ]"#;
        let calls = parser.parse(json).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[1].name, "fetch");
    }

    #[test]
    fn test_tools_parser_empty_array() {
        let parser = ToolsOutputParser;
        let calls = parser.parse("[]").unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn test_tools_parser_rejects_single_object() {
        // The text-based mode expects a JSON array, not a bare tool-call object.
        let parser = ToolsOutputParser;
        let json = r#"{"id": "tc_1", "name": "search", "arguments": {}}"#;
        assert!(parser.parse(json).is_err());
    }

    #[test]
    fn test_tools_parser_invalid_text() {
        let parser = ToolsOutputParser;
        let result = parser.parse("not json");
        assert!(result.is_err());
    }
}