tool_parser/parsers/
qwen_coder.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// Qwen Coder format parser for tool calls
14///
15/// Handles the Qwen Coder specific XML format:
16/// `<tool_call>\n<function=name>\n<parameter=key>value</parameter>\n</function>\n</tool_call>`
17///
18/// Features:
19/// - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
20/// - XML-style function declaration: `<function=name>`
21/// - XML-style parameters: `<parameter=key>value</parameter>`
22///
23/// Reference: https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8?chat_template=default
24pub struct QwenCoderParser {
25    /// Regex for extracting tool calls in parse_complete
26    extractor: Regex,
27
28    /// Buffer for accumulating incomplete patterns across chunks
29    buffer: String,
30
31    /// Stores complete tool call info (name and arguments) for each tool being parsed
32    prev_tool_call_arr: Vec<Value>,
33
34    /// Index of currently streaming tool call (-1 means no active tool)
35    current_tool_id: i32,
36
37    /// Flag for whether current tool's name has been sent to client
38    current_tool_name_sent: bool,
39
40    /// Tracks raw JSON string content streamed to client for each tool's arguments
41    streamed_args_for_tool: Vec<String>,
42
43    /// Token configuration
44    tool_call_start_token: &'static str,
45    tool_call_end_token: &'static str,
46
47    /// XML format streaming state
48    in_tool_call: bool,
49    current_function_name: String,
50    current_parameters: serde_json::Map<String, Value>,
51
52    /// Precompiled regex patterns for XML format parsing
53    xml_function_pattern: Regex,
54    xml_param_pattern: Regex,
55}
56
/// Decode HTML character references in `s` (a small subset of Python's `html.unescape`).
///
/// Recognised references are the named entities `amp`, `lt`, `gt`, `quot`, `apos` and
/// `nbsp`, plus decimal (`&#60;`) and hexadecimal (`&#x3C;`) numeric references. The
/// trailing `;` is optional. Anything that is not recognised, or that names an invalid
/// code point, is copied to the output exactly as it appeared in the input.
fn html_unescape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(amp) = rest.find('&') {
        // Everything before the ampersand is plain text.
        out.push_str(&rest[..amp]);
        let tail = &rest[amp + 1..];

        // The entity name is the longest run of alphanumerics / '#' after the '&'.
        let name_len = tail
            .find(|ch: char| !(ch.is_alphanumeric() || ch == '#'))
            .unwrap_or(tail.len());
        let name = &tail[..name_len];

        // A ';' directly after the name belongs to the entity and is consumed with it.
        let terminated = tail[name_len..].starts_with(';');
        let consumed = name_len + usize::from(terminated);

        let decoded: Option<char> = match name {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            "nbsp" => Some('\u{00A0}'),
            _ => name.strip_prefix('#').and_then(|digits| {
                let code = match digits.strip_prefix(|ch: char| ch == 'x' || ch == 'X') {
                    Some(hex) => u32::from_str_radix(hex, 16),
                    None => digits.parse::<u32>(),
                };
                code.ok().and_then(char::from_u32)
            }),
        };

        match decoded {
            Some(ch) => out.push(ch),
            // Unknown or invalid entity: emit the original text ('&', name, optional ';').
            None => out.push_str(&rest[amp..amp + 1 + consumed]),
        }

        rest = &tail[consumed..];
    }

    out.push_str(rest);
    out
}
127
128/// Parse a raw parameter value, similar to Python's _safe_val
129///
130/// 1. Decode HTML entities
131/// 2. Try to parse as JSON (numbers, booleans, null, objects, arrays)
132/// 3. Fall back to string if JSON parsing fails
133fn safe_val(raw: &str) -> Value {
134    let unescaped = html_unescape(raw.trim());
135
136    // Try JSON parsing first
137    if let Ok(v) = serde_json::from_str::<Value>(&unescaped) {
138        return v;
139    }
140
141    // Handle Python-style literals (True, False, None)
142    match unescaped.as_str() {
143        "True" => return Value::Bool(true),
144        "False" => return Value::Bool(false),
145        "None" => return Value::Null,
146        _ => {}
147    }
148
149    // Fall back to string
150    Value::String(unescaped)
151}
152
153impl QwenCoderParser {
154    /// Create a new Qwen Coder parser
155    pub fn new() -> Self {
156        // Support XML format: <tool_call>\n<function=name>\n<parameter=key>value</parameter>\n</function>\n</tool_call>
157        let pattern = r"(?s)<tool_call>\s*(.*?)\s*</tool_call>";
158        let extractor = Regex::new(pattern).expect("Valid regex pattern");
159
160        // Precompile XML format regex patterns for performance
161        let xml_function_pattern =
162            Regex::new(r"<function=([^>]+)>").expect("Valid XML function pattern");
163        let xml_param_pattern = Regex::new(r"(?s)<parameter=([^>]+)>(.*?)</parameter>")
164            .expect("Valid XML parameter pattern");
165
166        Self {
167            extractor,
168            buffer: String::new(),
169            prev_tool_call_arr: Vec::new(),
170            current_tool_id: -1,
171            current_tool_name_sent: false,
172            streamed_args_for_tool: Vec::new(),
173            tool_call_start_token: "<tool_call>",
174            tool_call_end_token: "</tool_call>",
175            in_tool_call: false,
176            current_function_name: String::new(),
177            current_parameters: serde_json::Map::new(),
178            xml_function_pattern,
179            xml_param_pattern,
180        }
181    }
182
183    /// Parse XML format tool call: <function=name><parameter=key>value</parameter></function>
184    fn parse_xml_format(&self, content: &str) -> ParserResult<Option<ToolCall>> {
185        let function_captures = self
186            .xml_function_pattern
187            .captures(content)
188            .ok_or_else(|| ParserError::ParsingFailed("No function name found".to_string()))?;
189
190        let function_name = function_captures
191            .get(1)
192            .ok_or_else(|| ParserError::ParsingFailed("Function name capture failed".to_string()))?
193            .as_str()
194            .trim()
195            .to_string();
196
197        if function_name.is_empty() {
198            return Ok(None);
199        }
200
201        let mut parameters = serde_json::Map::new();
202
203        for cap in self.xml_param_pattern.captures_iter(content) {
204            if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
205                let key = key_match.as_str().trim().to_string();
206                let value = value_match.as_str();
207                let json_value = safe_val(value);
208                parameters.insert(key, json_value);
209            }
210        }
211
212        let arguments = serde_json::to_string(&parameters)
213            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
214
215        Ok(Some(ToolCall {
216            function: FunctionCall {
217                name: function_name,
218                arguments,
219            },
220        }))
221    }
222
223    /// Parse and stream complete parameters from buffer
224    /// Returns tool call items to emit (similar to Python's _parse_and_stream_parameters)
225    fn parse_and_stream_parameters(&mut self) -> ParserResult<Vec<ToolCallItem>> {
226        let mut calls: Vec<ToolCallItem> = vec![];
227
228        // Find all complete parameter patterns in buffer
229        let mut new_params = serde_json::Map::new();
230        for cap in self.xml_param_pattern.captures_iter(&self.buffer) {
231            if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
232                let key = key_match.as_str().trim().to_string();
233                let value = value_match.as_str();
234                let json_value = safe_val(value);
235                new_params.insert(key, json_value);
236            }
237        }
238
239        // Calculate parameter diff and stream updates
240        if new_params != self.current_parameters {
241            let current_args = &mut self.streamed_args_for_tool[self.current_tool_id as usize];
242
243            if self.current_parameters.is_empty() {
244                // First parameter(s) - build JSON fragment (without closing brace)
245                let mut items = Vec::new();
246                for (key, value) in &new_params {
247                    let key_json =
248                        serde_json::to_string(key).unwrap_or_else(|_| format!("\"{}\"", key));
249                    let value_json = serde_json::to_string(value).unwrap_or_default();
250                    items.push(format!("{}: {}", key_json, value_json));
251                }
252                let json_fragment = format!("{{{}", items.join(", "));
253
254                calls.push(ToolCallItem {
255                    tool_index: self.current_tool_id as usize,
256                    name: None,
257                    parameters: json_fragment.clone(),
258                });
259                *current_args = json_fragment;
260            } else {
261                // Additional parameters - add them incrementally
262                let new_keys: Vec<_> = new_params
263                    .keys()
264                    .filter(|k| !self.current_parameters.contains_key(*k))
265                    .collect();
266
267                if !new_keys.is_empty() {
268                    let mut continuation_parts = Vec::new();
269                    for key in new_keys {
270                        if let Some(value) = new_params.get(key) {
271                            let key_json = serde_json::to_string(key)
272                                .unwrap_or_else(|_| format!("\"{}\"", key));
273                            let value_json = serde_json::to_string(value).unwrap_or_default();
274                            continuation_parts.push(format!("{}: {}", key_json, value_json));
275                        }
276                    }
277
278                    let json_fragment = format!(", {}", continuation_parts.join(", "));
279
280                    calls.push(ToolCallItem {
281                        tool_index: self.current_tool_id as usize,
282                        name: None,
283                        parameters: json_fragment.clone(),
284                    });
285                    current_args.push_str(&json_fragment);
286                }
287            }
288
289            // Update current state
290            self.current_parameters = new_params.clone();
291            if let Some(tool_obj) =
292                self.prev_tool_call_arr[self.current_tool_id as usize].as_object_mut()
293            {
294                tool_obj.insert("arguments".to_string(), Value::Object(new_params));
295            }
296        }
297
298        Ok(calls)
299    }
300
301    /// Reset streaming state for next tool call
302    fn reset_streaming_state(&mut self) {
303        self.in_tool_call = false;
304        self.current_tool_name_sent = false;
305        self.current_function_name.clear();
306        self.current_parameters.clear();
307    }
308}
309
310impl Default for QwenCoderParser {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
316#[async_trait]
317impl ToolParser for QwenCoderParser {
318    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
319        // Check if text contains Qwen Coder format
320        if !self.has_tool_markers(text) {
321            return Ok((text.to_string(), vec![]));
322        }
323
324        // Find where the first tool call begins
325        let idx = text.find(self.tool_call_start_token).unwrap();
326        let normal_text = text[..idx].to_string();
327
328        // Extract tool calls
329        let mut tools = Vec::new();
330        for captures in self.extractor.captures_iter(text) {
331            if let Some(content_str) = captures.get(1) {
332                let content = content_str.as_str().trim();
333
334                match self.parse_xml_format(content) {
335                    Ok(Some(tool)) => tools.push(tool),
336                    Ok(None) => continue,
337                    Err(e) => {
338                        tracing::warn!("Failed to parse XML tool call: {:?}", e);
339                        continue;
340                    }
341                }
342            }
343        }
344
345        // If no tools were successfully parsed despite having markers, return entire text
346        if tools.is_empty() {
347            return Ok((text.to_string(), vec![]));
348        }
349
350        Ok((normal_text, tools))
351    }
352
353    async fn parse_incremental(
354        &mut self,
355        chunk: &str,
356        tools: &[Tool],
357    ) -> ParserResult<StreamingParseResult> {
358        self.buffer.push_str(chunk);
359
360        let mut normal_text = String::new();
361        let mut calls: Vec<ToolCallItem> = vec![];
362
363        // Build tool indices for validation
364        let tool_indices = helpers::get_tool_indices(tools);
365
366        loop {
367            // If we're not in a tool call and don't see a start token, return normal text
368            if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
369                // Check for partial start token
370                if helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
371                    .is_none()
372                {
373                    normal_text.push_str(&self.buffer);
374                    self.buffer.clear();
375                }
376                break;
377            }
378
379            // Look for tool call start
380            if !self.in_tool_call {
381                if let Some(s) = self.buffer.find(self.tool_call_start_token) {
382                    normal_text.push_str(&self.buffer[..s]);
383                    self.buffer = self.buffer[s + self.tool_call_start_token.len()..].to_string();
384                    self.in_tool_call = true;
385                    self.current_tool_name_sent = false;
386                    self.current_function_name.clear();
387                    self.current_parameters.clear();
388                    continue;
389                } else {
390                    break;
391                }
392            }
393
394            // We're in a tool call, try to parse function name if not sent yet
395            if !self.current_tool_name_sent {
396                if let Some(captures) = self.xml_function_pattern.captures(&self.buffer) {
397                    if let Some(name_match) = captures.get(1) {
398                        let function_name = name_match.as_str().trim().to_string();
399
400                        // Validate function name
401                        if tool_indices.contains_key(&function_name) {
402                            self.current_function_name = function_name.clone();
403                            self.current_tool_name_sent = true;
404
405                            // Initialize tool call tracking
406                            if self.current_tool_id == -1 {
407                                self.current_tool_id = 0;
408                            }
409
410                            // Ensure tracking arrays are large enough
411                            helpers::ensure_capacity(
412                                self.current_tool_id,
413                                &mut self.prev_tool_call_arr,
414                                &mut self.streamed_args_for_tool,
415                            );
416
417                            // Store tool call info
418                            self.prev_tool_call_arr[self.current_tool_id as usize] = serde_json::json!({
419                                "name": function_name,
420                                "arguments": {}
421                            });
422
423                            // Send tool name
424                            calls.push(ToolCallItem {
425                                tool_index: self.current_tool_id as usize,
426                                name: Some(function_name),
427                                parameters: String::new(),
428                            });
429
430                            // Remove processed function declaration from buffer
431                            self.buffer = self.buffer[captures.get(0).unwrap().end()..].to_string();
432                            continue;
433                        } else {
434                            // Invalid function name, reset state
435                            tracing::warn!("Invalid function name: {}", function_name);
436                            self.reset_streaming_state();
437                            normal_text.push_str(&self.buffer);
438                            self.buffer.clear();
439                            break;
440                        }
441                    }
442                } else {
443                    // Function name not complete yet, wait for more text
444                    break;
445                }
446            }
447
448            // Parse parameters (only complete ones)
449            if self.current_tool_name_sent {
450                let param_calls = self.parse_and_stream_parameters()?;
451                calls.extend(param_calls);
452
453                // Check if tool call is complete
454                if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
455                    // Close JSON object if we have parameters
456                    let current_args = &self.streamed_args_for_tool[self.current_tool_id as usize];
457                    if !current_args.is_empty() {
458                        // Count braces to check if JSON is complete
459                        let open_braces = current_args.matches('{').count();
460                        let close_braces = current_args.matches('}').count();
461                        if open_braces > close_braces {
462                            calls.push(ToolCallItem {
463                                tool_index: self.current_tool_id as usize,
464                                name: None,
465                                parameters: "}".to_string(),
466                            });
467                            self.streamed_args_for_tool[self.current_tool_id as usize].push('}');
468                        }
469                    }
470
471                    // Complete the tool call
472                    self.buffer =
473                        self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
474                    self.reset_streaming_state();
475                    self.current_tool_id += 1;
476                    continue;
477                } else {
478                    // Tool call not complete yet, wait for more text
479                    break;
480                }
481            }
482
483            break;
484        }
485
486        Ok(StreamingParseResult { normal_text, calls })
487    }
488
489    fn has_tool_markers(&self, text: &str) -> bool {
490        text.contains(self.tool_call_start_token)
491    }
492
493    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
494        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
495    }
496
497    fn reset(&mut self) {
498        helpers::reset_parser_state(
499            &mut self.buffer,
500            &mut self.prev_tool_call_arr,
501            &mut self.current_tool_id,
502            &mut self.current_tool_name_sent,
503            &mut self.streamed_args_for_tool,
504        );
505        self.reset_streaming_state();
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `html_unescape` against a table of `(input, expected)` pairs.
    fn check_unescape(cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(html_unescape(input), *expected);
        }
    }

    /// Checks `safe_val` against a table of `(input, expected)` pairs.
    fn check_safe_val(cases: Vec<(&str, Value)>) {
        for (input, expected) in cases {
            assert_eq!(safe_val(input), expected);
        }
    }

    #[test]
    fn test_html_unescape_basic() {
        check_unescape(&[
            ("&amp;", "&"),
            ("&lt;", "<"),
            ("&gt;", ">"),
            ("&quot;", "\""),
            ("&apos;", "'"),
        ]);
    }

    #[test]
    fn test_html_unescape_numeric() {
        check_unescape(&[("&#60;", "<"), ("&#x3C;", "<"), ("&#x3c;", "<")]);
    }

    #[test]
    fn test_html_unescape_mixed() {
        check_unescape(&[("Hello &amp; World &lt;tag&gt;", "Hello & World <tag>")]);
    }

    #[test]
    fn test_html_unescape_unknown() {
        check_unescape(&[
            // Unknown entities with semicolon should be preserved as-is
            ("&unknown;", "&unknown;"),
            // Unterminated entities should NOT have semicolon added
            ("&foo bar", "&foo bar"),
            ("&", "&"),
            ("& ", "& "),
        ]);
    }

    #[test]
    fn test_safe_val_json() {
        check_safe_val(vec![
            ("42", Value::Number(42.into())),
            ("1.5", serde_json::json!(1.5)),
            ("true", Value::Bool(true)),
            ("false", Value::Bool(false)),
            ("null", Value::Null),
            (r#"{"key": "value"}"#, serde_json::json!({"key": "value"})),
            (r#"[1, 2, 3]"#, serde_json::json!([1, 2, 3])),
        ]);
    }

    #[test]
    fn test_safe_val_python_literals() {
        check_safe_val(vec![
            ("True", Value::Bool(true)),
            ("False", Value::Bool(false)),
            ("None", Value::Null),
        ]);
    }

    #[test]
    fn test_safe_val_string_fallback() {
        check_safe_val(vec![
            ("hello world", Value::String("hello world".to_string())),
            ("  spaces  ", Value::String("spaces".to_string())),
        ]);
    }

    #[test]
    fn test_safe_val_html_entities() {
        check_safe_val(vec![
            ("&lt;div&gt;", Value::String("<div>".to_string())),
            ("Tom &amp; Jerry", Value::String("Tom & Jerry".to_string())),
        ]);
    }
}