1use crate::ai::ToolUseData;
2use anyhow::{bail, Result};
3use serde_json::Value;
4use tracing::debug;
5use uuid::Uuid;
6
7fn find_opening_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
8 let mut pos = 0;
9 while pos < text.len() {
10 let Some(lt_pos) = text[pos..].find('<') else {
11 return None;
12 };
13 let abs_lt = pos + lt_pos;
14
15 let Some(gt_pos) = text[abs_lt..].find('>') else {
16 return None;
17 };
18 let tag_content = &text[abs_lt + 1..abs_lt + gt_pos];
19
20 let tag_name = tag_content.split_whitespace().next().unwrap_or("");
22 if tag_name == base_name || tag_name.ends_with(&format!(":{}", base_name)) {
23 return Some((abs_lt, abs_lt + gt_pos + 1));
24 }
25
26 pos = abs_lt + 1;
27 }
28 None
29}
30
31fn find_closing_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
32 find_closing_tag_with_nesting(text, base_name, 1)
33}
34
35fn find_closing_tag_with_nesting(
36 text: &str,
37 base_name: &str,
38 initial_depth: usize,
39) -> Option<(usize, usize)> {
40 let mut depth = initial_depth;
41 let mut pos = 0;
42
43 while pos < text.len() {
44 let open_pos = find_opening_tag(&text[pos..], base_name);
45 let close_pos = find_first_closing_tag(&text[pos..], base_name);
46
47 if let (Some((o_start, _)), Some((c_start, _))) = (open_pos, close_pos) {
48 if o_start < c_start {
49 depth += 1;
50 pos += o_start + 1;
51 continue;
52 }
53 }
54
55 if let Some((c_start, c_end)) = close_pos {
56 depth -= 1;
57 if depth == 0 {
58 return Some((pos + c_start, pos + c_end));
59 }
60 pos += c_end;
61 continue;
62 }
63
64 if let Some((o_start, _)) = open_pos {
65 depth += 1;
66 pos += o_start + 1;
67 continue;
68 }
69
70 return None;
71 }
72 None
73}
74
/// Finds the first closing tag for `base_name` in `text`, ignoring nesting.
///
/// A closer matches when the text between `</` and `>`, with surrounding
/// whitespace trimmed, equals `base_name` or ends with `:base_name`.
///
/// Returns the byte range `(start, end)` of the tag, from `<` to one past `>`.
/// Scanning stops with `None` at the first `</` that has no `>` after it.
fn find_first_closing_tag(text: &str, base_name: &str) -> Option<(usize, usize)> {
    let prefixed = format!(":{}", base_name);
    for (start, _) in text.match_indices("</") {
        let gt_offset = text[start..].find('>')?;
        let name = text[start + 2..start + gt_offset].trim();
        if name == base_name || name.ends_with(&prefixed) {
            return Some((start, start + gt_offset + 1));
        }
    }
    None
}
96
/// Finds the first opening `base_name` tag (optionally namespace-prefixed) that
/// carries a double-quoted `name` attribute.
///
/// Tags of the right name without such an attribute are skipped. The attribute
/// must be preceded by whitespace, so other attributes ending in `name` (for
/// example `filename="..."`) are not mistaken for it. The old substring search
/// did mistake them.
///
/// Returns `(start, end, name)`:
/// - `(start, end)` is the byte range of the whole tag;
/// - `name` is the attribute value, borrowed from `text`.
fn find_named_opening_tag<'a>(text: &'a str, base_name: &str) -> Option<(usize, usize, &'a str)> {
    const NAME_ATTR: &str = "name=\"";
    let prefixed = format!(":{}", base_name);
    let mut pos = 0;
    while pos < text.len() {
        // No '<' (or no '>' after it) means no complete tag can follow.
        let abs_lt = pos + text[pos..].find('<')?;
        let gt_pos = text[abs_lt..].find('>')?;
        let tag_content: &'a str = &text[abs_lt + 1..abs_lt + gt_pos];

        let tag_name = tag_content.split_whitespace().next().unwrap_or("");
        if tag_name == base_name || tag_name.ends_with(&prefixed) {
            let name = tag_content
                .match_indices(NAME_ATTR)
                .find_map(move |(idx, _)| {
                    // Only accept `name="` when it starts a new attribute.
                    if !tag_content[..idx].ends_with(char::is_whitespace) {
                        return None;
                    }
                    let value_start = idx + NAME_ATTR.len();
                    let value_len = tag_content[value_start..].find('"')?;
                    Some(&tag_content[value_start..value_start + value_len])
                });
            if let Some(name) = name {
                return Some((abs_lt, abs_lt + gt_pos + 1, name));
            }
        }

        pos = abs_lt + 1;
    }
    None
}
126}
127
128pub fn parse_xml_tool_calls(text: &str) -> Result<(Vec<ToolUseData>, String)> {
130 let mut tool_calls = Vec::new();
131 let mut remaining_text = String::new();
132 let mut last_end = 0;
133
134 let mut search_start = 0;
135 while let Some((open_start, open_end)) =
136 find_opening_tag(&text[search_start..], "function_calls")
137 {
138 let abs_open_start = search_start + open_start;
139 let abs_open_end = search_start + open_end;
140
141 let Some((close_start, close_end)) =
142 find_closing_tag(&text[abs_open_end..], "function_calls")
143 else {
144 bail!("Unclosed function_calls tag at position {}", abs_open_start);
145 };
146 let abs_close_start = abs_open_end + close_start;
147 let abs_close_end = abs_open_end + close_end;
148
149 remaining_text.push_str(&text[last_end..abs_open_start]);
150
151 let block_content = &text[abs_open_end..abs_close_start];
152 let parsed = parse_invoke_blocks(block_content)?;
153 tool_calls.extend(parsed);
154
155 last_end = abs_close_end;
156 search_start = abs_close_end;
157 }
158
159 remaining_text.push_str(&text[last_end..]);
160
161 debug!(
162 tool_count = tool_calls.len(),
163 remaining_len = remaining_text.len(),
164 "Parsed XML tool calls"
165 );
166
167 Ok((tool_calls, remaining_text.trim().to_string()))
168}
169
170fn parse_invoke_blocks(content: &str) -> Result<Vec<ToolUseData>> {
171 let mut tool_calls = Vec::new();
172
173 let mut search_start = 0;
174 while let Some((_open_start, open_end, name)) =
175 find_named_opening_tag(&content[search_start..], "invoke")
176 {
177 let abs_open_end = search_start + open_end;
178
179 let Some((close_start, close_end)) = find_closing_tag(&content[abs_open_end..], "invoke")
180 else {
181 bail!("Unclosed invoke tag for tool '{}'", name);
182 };
183 let abs_close_start = abs_open_end + close_start;
184 let abs_close_end = abs_open_end + close_end;
185
186 let invoke_content = &content[abs_open_end..abs_close_start];
187 let parameters = parse_parameters(invoke_content)?;
188
189 tool_calls.push(ToolUseData {
190 id: Uuid::new_v4().to_string(),
191 name: name.to_string(),
192 arguments: parameters,
193 });
194
195 search_start = abs_close_end;
196 }
197
198 Ok(tool_calls)
199}
200
201fn parse_parameters(content: &str) -> Result<Value> {
202 let mut params = serde_json::Map::new();
203
204 let mut search_start = 0;
205 while let Some((_open_start, open_end, name)) =
206 find_named_opening_tag(&content[search_start..], "parameter")
207 {
208 let abs_open_end = search_start + open_end;
209
210 let Some((close_start, close_end)) =
211 find_closing_tag(&content[abs_open_end..], "parameter")
212 else {
213 bail!("Unclosed parameter tag for '{}'", name);
214 };
215 let abs_close_start = abs_open_end + close_start;
216 let abs_close_end = abs_open_end + close_end;
217
218 let value_str = &content[abs_open_end..abs_close_start];
219
220 let value = match serde_json::from_str(value_str) {
222 Ok(v) => v,
223 Err(_) => Value::String(value_str.to_string()),
224 };
225
226 params.insert(name.to_string(), value);
227 search_start = abs_close_end;
228 }
229
230 Ok(Value::Object(params))
231}
232
// Unit tests for the XML tool-call parser (`parse_xml_tool_calls` and helpers).
#[cfg(test)]
mod tests {
    use super::*;

    // One block with surrounding prose. Checks three things:
    // - "value1" (not valid JSON) is kept as a string;
    // - "42" is parsed as a JSON number;
    // - the prose on both sides survives in `remaining`.
    #[test]
    fn test_parse_single_tool_call() {
        let input = r#"Some text before
<function_calls>
<invoke name="test_tool">
<parameter name="param1">value1</parameter>
<parameter name="param2">42</parameter>
</invoke>
</function_calls>
Some text after"#;

        let (calls, remaining) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert_eq!(calls[0].arguments["param1"], "value1");
        assert_eq!(calls[0].arguments["param2"], 42);
        assert!(remaining.contains("Some text before"));
        assert!(remaining.contains("Some text after"));
    }

    // Several `invoke` elements in one block are returned in document order.
    #[test]
    fn test_parse_multiple_tool_calls() {
        let input = r#"<function_calls>
<invoke name="tool1">
<parameter name="a">1</parameter>
</invoke>
<invoke name="tool2">
<parameter name="b">2</parameter>
</invoke>
</function_calls>"#;

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "tool1");
        assert_eq!(calls[1].name, "tool2");
    }

    // Parameter bodies that are valid JSON arrays/objects are stored as structured JSON.
    #[test]
    fn test_parse_json_parameter() {
        let input = r#"<function_calls>
<invoke name="test">
<parameter name="arr">["a", "b", "c"]</parameter>
<parameter name="obj">{"key": "value"}</parameter>
</invoke>
</function_calls>"#;

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert!(calls[0].arguments["arr"].is_array());
        assert!(calls[0].arguments["obj"].is_object());
    }

    // Input without any tool-call markup yields no calls and is returned unchanged.
    #[test]
    fn test_no_tool_calls() {
        let input = "Just regular text without any tool calls";
        let (calls, remaining) = parse_xml_tool_calls(input).unwrap();

        assert!(calls.is_empty());
        assert_eq!(remaining, input);
    }

    // Namespace-prefixed tags (`<prefix:function_calls>` etc.) are recognized.
    #[test]
    fn test_parse_with_xml_prefix() {
        let prefix = "antml";
        let input = format!(
            "<{}:function_calls>\n<{}:invoke name=\"prefixed_tool\">\n<{}:parameter name=\"key\">value</{}:parameter>\n</{}:invoke>\n</{}:function_calls>",
            prefix, prefix, prefix, prefix, prefix, prefix
        );

        let (calls, _) = parse_xml_tool_calls(&input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "prefixed_tool");
        assert_eq!(calls[0].arguments["key"], "value");
    }

    // Prefixes are matched only by the `:name` suffix, so an opening tag and its
    // closing tag may carry different prefixes.
    #[test]
    fn test_parse_with_mixed_prefixes() {
        let input = "<abc:function_calls>\n<xyz:invoke name=\"mixed\">\n<foo:parameter name=\"p\">val</bar:parameter>\n</qux:invoke>\n</def:function_calls>";

        let (calls, _) = parse_xml_tool_calls(input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "mixed");
        assert_eq!(calls[0].arguments["p"], "val");
    }

    // A parameter value that itself contains a complete function_calls block must
    // not be treated as a second tool call. Nesting depth keeps the inner tags
    // inside the outer parameter. The markup is assembled from an `lt` char at
    // runtime — presumably so this source file contains no literal tool-call
    // markup (TODO confirm intent).
    #[test]
    fn test_nested_tool_call_in_parameter() {
        let lt = '<';
        let inner = format!(
            "{}function_calls>\n{}invoke name=\"nested_example\">\n{}parameter name=\"k\">v{}/parameter>\n{}/invoke>\n{}/function_calls>",
            lt, lt, lt, lt, lt, lt
        );
        let input = format!(
            "{}function_calls>\n{}invoke name=\"write_file\">\n{}parameter name=\"path\">x.md{}/parameter>\n{}parameter name=\"content\">{}{}/parameter>\n{}/invoke>\n{}/function_calls>",
            lt, lt, lt, lt, lt, inner, lt, lt, lt
        );

        let (calls, remaining) = parse_xml_tool_calls(&input).unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "write_file");
        assert_eq!(calls[0].arguments["path"], "x.md");

        let content = calls[0].arguments.get("content");
        assert!(content.is_some());
        assert!(remaining.is_empty());
    }
}
352}