Skip to main content

tycode_core/chat/
xml_tool_parser.rs

1use crate::ai::ToolUseData;
2use anyhow::{bail, Result};
3use serde_json::Value;
4use tracing::debug;
5use uuid::Uuid;
6
/// Locates the first opening tag in `text` whose name is `base_name`, with or
/// without a namespace-style prefix (for example `<foo>` or `<ns:foo attr="x">`).
///
/// Returns the byte offsets `(start, end)` of the tag, where `start` is the
/// index of `<` and `end` is one past the `>` that terminates it. Returns
/// `None` when no such tag exists or when a `<` has no `>` after it.
///
/// Closing tags are never reported. Without the explicit check, a prefixed
/// closing tag such as `</ns:foo>` has the tag name `/ns:foo`, which satisfies
/// the `:foo` suffix test and would be mistaken for an opening tag.
fn find_opening_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
    let mut pos = 0;
    while let Some(lt_offset) = text[pos..].find('<') {
        let tag_start = pos + lt_offset;
        // No '>' after this '<' means no later '<' can be terminated either.
        let gt_offset = text[tag_start..].find('>')?;
        let tag_content = &text[tag_start + 1..tag_start + gt_offset];

        // The tag name is the first whitespace-delimited word; attributes follow it.
        let tag_name = tag_content.split_whitespace().next().unwrap_or("");
        let is_closing = tag_name.starts_with('/');
        // Accept `base_name` itself or `<anything>:base_name`, without allocating.
        let name_matches = tag_name
            .strip_suffix(base_name)
            .map_or(false, |prefix| prefix.is_empty() || prefix.ends_with(':'));
        if !is_closing && name_matches {
            return Some((tag_start, tag_start + gt_offset + 1));
        }

        // Resume just past this '<' so a real tag that follows stray '<' text is still found.
        pos = tag_start + 1;
    }
    None
}
30
31fn find_closing_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
32    find_closing_tag_with_nesting(text, base_name, 1)
33}
34
35fn find_closing_tag_with_nesting(
36    text: &str,
37    base_name: &str,
38    initial_depth: usize,
39) -> Option<(usize, usize)> {
40    let mut depth = initial_depth;
41    let mut pos = 0;
42
43    while pos < text.len() {
44        let open_pos = find_opening_tag(&text[pos..], base_name);
45        let close_pos = find_first_closing_tag(&text[pos..], base_name);
46
47        if let (Some((o_start, _)), Some((c_start, _))) = (open_pos, close_pos) {
48            if o_start < c_start {
49                depth += 1;
50                pos += o_start + 1;
51                continue;
52            }
53        }
54
55        if let Some((c_start, c_end)) = close_pos {
56            depth -= 1;
57            if depth == 0 {
58                return Some((pos + c_start, pos + c_end));
59            }
60            pos += c_end;
61            continue;
62        }
63
64        if let Some((o_start, _)) = open_pos {
65            depth += 1;
66            pos += o_start + 1;
67            continue;
68        }
69
70        return None;
71    }
72    None
73}
74
/// Returns the byte range of the first closing tag in `text` named
/// `base_name`, optionally carrying a `prefix:` (for example `</foo>` or
/// `</ns:foo>`). Nesting is not considered here.
///
/// The range starts at `<` and ends one past `>`. Whitespace around the name
/// inside the tag is ignored. `None` is returned when no matching closing tag
/// exists or when a `</` is never terminated by `>`.
fn find_first_closing_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
    let mut scan_from = 0;
    while let Some(offset) = text[scan_from..].find("</") {
        let tag_start = scan_from + offset;
        let tag_len = text[tag_start..].find('>')?;
        // Everything between "</" and ">" is the (possibly padded) tag name.
        let name = text[tag_start + 2..tag_start + tag_len].trim();

        let is_match = name
            .strip_suffix(base_name)
            .map_or(false, |prefix| prefix.is_empty() || prefix.ends_with(':'));
        if is_match {
            return Some((tag_start, tag_start + tag_len + 1));
        }

        // Skip past this "</" and look for the next candidate.
        scan_from = tag_start + 2;
    }
    None
}
96
/// Locates the first opening `base_name` tag (optionally `prefix:base_name`)
/// that carries a `name="…"` attribute, and returns that attribute's value.
///
/// The result is `(start, end, name)`: `start` is the byte index of `<`, `end`
/// is one past the terminating `>`, and `name` borrows the attribute value
/// from `text`. Matching tags without a usable `name` attribute are skipped and
/// the search continues. `None` means no such tag was found, or a `<` had no
/// `>` after it.
///
/// Two guards keep the match precise:
/// * closing tags are ignored, since `/ns:base_name` would otherwise pass the
///   `:base_name` suffix test;
/// * `name="` must be preceded by whitespace, so a longer attribute such as
///   `filename="…"` is not mistaken for the `name` attribute.
fn find_named_opening_tag<'a>(text: &'a str, base_name: &str) -> Option<(usize, usize, &'a str)> {
    const NAME_ATTR: &str = "name=\"";

    let mut pos = 0;
    while let Some(lt_offset) = text[pos..].find('<') {
        let tag_start = pos + lt_offset;
        // No '>' after this '<' means no later '<' can be terminated either.
        let gt_offset = text[tag_start..].find('>')?;
        let tag_content = &text[tag_start + 1..tag_start + gt_offset];

        // The tag name is the first whitespace-delimited word.
        let tag_name = tag_content.split_whitespace().next().unwrap_or("");
        let is_closing = tag_name.starts_with('/');
        let name_matches = tag_name
            .strip_suffix(base_name)
            .map_or(false, |prefix| prefix.is_empty() || prefix.ends_with(':'));

        if !is_closing && name_matches {
            // First `name="` that starts an attribute of its own.
            let value_start = tag_content
                .match_indices(NAME_ATTR)
                .find(|(idx, _)| tag_content[..*idx].ends_with(char::is_whitespace))
                .map(|(idx, _)| idx + NAME_ATTR.len());

            if let Some(value_start) = value_start {
                if let Some(value_len) = tag_content[value_start..].find('"') {
                    let name = &tag_content[value_start..value_start + value_len];
                    return Some((tag_start, tag_start + gt_offset + 1, name));
                }
            }
        }

        pos = tag_start + 1;
    }
    None
}
127
128/// Permissive matching allows any XML prefix to handle variation in model outputs.
129pub fn parse_xml_tool_calls(text: &str) -> Result<(Vec<ToolUseData>, String)> {
130    let mut tool_calls = Vec::new();
131    let mut remaining_text = String::new();
132    let mut last_end = 0;
133
134    let mut search_start = 0;
135    while let Some((open_start, open_end)) =
136        find_opening_tag(&text[search_start..], "function_calls")
137    {
138        let abs_open_start = search_start + open_start;
139        let abs_open_end = search_start + open_end;
140
141        let Some((close_start, close_end)) =
142            find_closing_tag(&text[abs_open_end..], "function_calls")
143        else {
144            bail!("Unclosed function_calls tag at position {}", abs_open_start);
145        };
146        let abs_close_start = abs_open_end + close_start;
147        let abs_close_end = abs_open_end + close_end;
148
149        remaining_text.push_str(&text[last_end..abs_open_start]);
150
151        let block_content = &text[abs_open_end..abs_close_start];
152        let parsed = parse_invoke_blocks(block_content)?;
153        tool_calls.extend(parsed);
154
155        last_end = abs_close_end;
156        search_start = abs_close_end;
157    }
158
159    remaining_text.push_str(&text[last_end..]);
160
161    debug!(
162        tool_count = tool_calls.len(),
163        remaining_len = remaining_text.len(),
164        "Parsed XML tool calls"
165    );
166
167    Ok((tool_calls, remaining_text.trim().to_string()))
168}
169
170fn parse_invoke_blocks(content: &str) -> Result<Vec<ToolUseData>> {
171    let mut tool_calls = Vec::new();
172
173    let mut search_start = 0;
174    while let Some((_open_start, open_end, name)) =
175        find_named_opening_tag(&content[search_start..], "invoke")
176    {
177        let abs_open_end = search_start + open_end;
178
179        let Some((close_start, close_end)) = find_closing_tag(&content[abs_open_end..], "invoke")
180        else {
181            bail!("Unclosed invoke tag for tool '{}'", name);
182        };
183        let abs_close_start = abs_open_end + close_start;
184        let abs_close_end = abs_open_end + close_end;
185
186        let invoke_content = &content[abs_open_end..abs_close_start];
187        let parameters = parse_parameters(invoke_content)?;
188
189        tool_calls.push(ToolUseData {
190            id: Uuid::new_v4().to_string(),
191            name: name.to_string(),
192            arguments: parameters,
193        });
194
195        search_start = abs_close_end;
196    }
197
198    Ok(tool_calls)
199}
200
201fn parse_parameters(content: &str) -> Result<Value> {
202    let mut params = serde_json::Map::new();
203
204    let mut search_start = 0;
205    while let Some((_open_start, open_end, name)) =
206        find_named_opening_tag(&content[search_start..], "parameter")
207    {
208        let abs_open_end = search_start + open_end;
209
210        let Some((close_start, close_end)) =
211            find_closing_tag(&content[abs_open_end..], "parameter")
212        else {
213            bail!("Unclosed parameter tag for '{}'", name);
214        };
215        let abs_close_start = abs_open_end + close_start;
216        let abs_close_end = abs_open_end + close_end;
217
218        let value_str = &content[abs_open_end..abs_close_start];
219
220        // Specification requires arrays/objects as JSON, scalars as strings
221        let value = match serde_json::from_str(value_str) {
222            Ok(v) => v,
223            Err(_) => Value::String(value_str.to_string()),
224        };
225
226        params.insert(name.to_string(), value);
227        search_start = abs_close_end;
228    }
229
230    Ok(Value::Object(params))
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// One block with surrounding prose: the call is extracted and the prose
    /// on both sides is kept.
    #[test]
    fn test_parse_single_tool_call() {
        let input = r#"Some text before
<function_calls>
<invoke name="test_tool">
<parameter name="param1">value1</parameter>
<parameter name="param2">42</parameter>
</invoke>
</function_calls>
Some text after"#;

        let (calls, remaining) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert_eq!(calls[0].arguments["param1"], "value1");
        // "42" is valid JSON, so it is stored as a number rather than a string.
        assert_eq!(calls[0].arguments["param2"], 42);
        assert!(remaining.contains("Some text before"));
        assert!(remaining.contains("Some text after"));
    }

    /// Several `invoke` elements in one block are returned in document order.
    #[test]
    fn test_parse_multiple_tool_calls() {
        let input = r#"<function_calls>
<invoke name="tool1">
<parameter name="a">1</parameter>
</invoke>
<invoke name="tool2">
<parameter name="b">2</parameter>
</invoke>
</function_calls>"#;

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "tool1");
        assert_eq!(calls[1].name, "tool2");
    }

    /// Array and object parameter values are parsed as JSON.
    #[test]
    fn test_parse_json_parameter() {
        let input = r#"<function_calls>
<invoke name="test">
<parameter name="arr">["a", "b", "c"]</parameter>
<parameter name="obj">{"key": "value"}</parameter>
</invoke>
</function_calls>"#;

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert!(calls[0].arguments["arr"].is_array());
        assert!(calls[0].arguments["obj"].is_object());
    }

    /// Input without any block yields no calls and is returned unchanged.
    #[test]
    fn test_no_tool_calls() {
        let input = "Just regular text without any tool calls";
        let (calls, remaining) = parse_xml_tool_calls(input).unwrap();

        assert!(calls.is_empty());
        assert_eq!(remaining, input);
    }

    #[test]
    fn test_parse_with_xml_prefix() {
        // Parser accepts any XML prefix on tags
        let prefix = "antml";
        let input = format!(
            "<{}:function_calls>\n<{}:invoke name=\"prefixed_tool\">\n<{}:parameter name=\"key\">value</{}:parameter>\n</{}:invoke>\n</{}:function_calls>",
            prefix, prefix, prefix, prefix, prefix, prefix
        );

        let (calls, _) = parse_xml_tool_calls(&input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "prefixed_tool");
        assert_eq!(calls[0].arguments["key"], "value");
    }

    #[test]
    fn test_parse_with_mixed_prefixes() {
        // Parser accepts different prefixes on different tags
        let input = "<abc:function_calls>\n<xyz:invoke name=\"mixed\">\n<foo:parameter name=\"p\">val</bar:parameter>\n</qux:invoke>\n</def:function_calls>";

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "mixed");
        assert_eq!(calls[0].arguments["p"], "val");
    }

    #[test]
    fn test_nested_tool_call_in_parameter() {
        // Scenario: AI writes a file containing an XML tool call example.
        // The tags are assembled from pieces rather than written literally.
        let lt = '<';
        let inner = format!(
            "{}function_calls>\n{}invoke name=\"nested_example\">\n{}parameter name=\"k\">v{}/parameter>\n{}/invoke>\n{}/function_calls>",
            lt, lt, lt, lt, lt, lt
        );
        let input = format!(
            "{}function_calls>\n{}invoke name=\"write_file\">\n{}parameter name=\"path\">x.md{}/parameter>\n{}parameter name=\"content\">{}{}/parameter>\n{}/invoke>\n{}/function_calls>",
            lt, lt, lt, lt, lt, inner, lt, lt, lt
        );

        let (calls, remaining) = parse_xml_tool_calls(&input).unwrap();

        // Only the outer call is a real tool call; the inner one is payload.
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "write_file");
        assert_eq!(calls[0].arguments["path"], "x.md");

        // The embedded example must come through verbatim as the parameter
        // value, not merely be present.
        assert_eq!(
            calls[0].arguments["content"].as_str(),
            Some(inner.as_str())
        );
        assert!(remaining.is_empty());
    }
}