Skip to main content

tool_parser/parsers/
qwen_coder.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// Qwen Coder format parser for tool calls
14///
15/// Handles the Qwen Coder specific XML format:
16/// `<tool_call>\n<function=name>\n<parameter=key>value</parameter>\n</function>\n</tool_call>`
17///
18/// Features:
19/// - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
20/// - XML-style function declaration: `<function=name>`
21/// - XML-style parameters: `<parameter=key>value</parameter>`
22///
23/// Reference: https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8?chat_template=default
24pub struct QwenCoderParser {
25    /// Regex for extracting tool calls in parse_complete
26    extractor: Regex,
27
28    /// Buffer for accumulating incomplete patterns across chunks
29    buffer: String,
30
31    /// Stores complete tool call info (name and arguments) for each tool being parsed
32    prev_tool_call_arr: Vec<Value>,
33
34    /// Index of currently streaming tool call (-1 means no active tool)
35    current_tool_id: i32,
36
37    /// Flag for whether current tool's name has been sent to client
38    current_tool_name_sent: bool,
39
40    /// Tracks raw JSON string content streamed to client for each tool's arguments
41    streamed_args_for_tool: Vec<String>,
42
43    /// Token configuration
44    tool_call_start_token: &'static str,
45    tool_call_end_token: &'static str,
46
47    /// XML format streaming state
48    in_tool_call: bool,
49    current_function_name: String,
50    current_parameters: serde_json::Map<String, Value>,
51
52    /// Precompiled regex patterns for XML format parsing
53    xml_function_pattern: Regex,
54    xml_param_pattern: Regex,
55}
56
/// Decode a small set of HTML character references in `s`.
///
/// Recognised references are the named entities `amp`, `lt`, `gt`, `quot`, `apos` and `nbsp`,
/// plus decimal (`&#60;`) and hexadecimal (`&#x3C;`) numeric references. The terminating `;`
/// is optional. Any other `&…` sequence, including numeric references that do not map to a
/// valid `char`, is copied to the output unchanged.
fn html_unescape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(amp) = rest.find('&') {
        // Everything before the ampersand is plain text.
        out.push_str(&rest[..amp]);
        let tail = &rest[amp + 1..];

        // The reference body is the longest run of alphanumerics and '#'.
        let body_len = tail
            .find(|ch: char| !(ch.is_alphanumeric() || ch == '#'))
            .unwrap_or(tail.len());
        let body = &tail[..body_len];

        // A ';' directly after the body belongs to the reference and is consumed with it.
        let reference_len = if tail[body_len..].starts_with(';') {
            body_len + 1
        } else {
            body_len
        };

        let decoded = match body {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            "nbsp" => Some('\u{00A0}'),
            _ => body.strip_prefix('#').and_then(|digits| {
                let hex_digits = digits
                    .strip_prefix('x')
                    .or_else(|| digits.strip_prefix('X'));
                let code_point = match hex_digits {
                    Some(hex) => u32::from_str_radix(hex, 16).ok(),
                    None => digits.parse::<u32>().ok(),
                };
                code_point.and_then(char::from_u32)
            }),
        };

        match decoded {
            Some(ch) => out.push(ch),
            None => {
                // Unknown or invalid reference: reproduce exactly what was consumed.
                out.push('&');
                out.push_str(&tail[..reference_len]);
            }
        }

        rest = &tail[reference_len..];
    }

    out.push_str(rest);
    out
}
130
131/// Parse a raw parameter value, similar to Python's _safe_val
132///
133/// 1. Decode HTML entities
134/// 2. Try to parse as JSON (numbers, booleans, null, objects, arrays)
135/// 3. Fall back to string if JSON parsing fails
136fn safe_val(raw: &str) -> Value {
137    let unescaped = html_unescape(raw.trim());
138
139    // Try JSON parsing first
140    if let Ok(v) = serde_json::from_str::<Value>(&unescaped) {
141        return v;
142    }
143
144    // Handle Python-style literals (True, False, None)
145    match unescaped.as_str() {
146        "True" => return Value::Bool(true),
147        "False" => return Value::Bool(false),
148        "None" => return Value::Null,
149        _ => {}
150    }
151
152    // Fall back to string
153    Value::String(unescaped)
154}
155
156impl QwenCoderParser {
157    /// Create a new Qwen Coder parser
158    #[expect(
159        clippy::expect_used,
160        reason = "regex patterns are compile-time string literals"
161    )]
162    pub fn new() -> Self {
163        // Support XML format: <tool_call>\n<function=name>\n<parameter=key>value</parameter>\n</function>\n</tool_call>
164        let pattern = r"(?s)<tool_call>\s*(.*?)\s*</tool_call>";
165        let extractor = Regex::new(pattern).expect("Valid regex pattern");
166
167        // Precompile XML format regex patterns for performance
168        let xml_function_pattern =
169            Regex::new(r"<function=([^>]+)>").expect("Valid XML function pattern");
170        let xml_param_pattern = Regex::new(r"(?s)<parameter=([^>]+)>(.*?)</parameter>")
171            .expect("Valid XML parameter pattern");
172
173        Self {
174            extractor,
175            buffer: String::new(),
176            prev_tool_call_arr: Vec::new(),
177            current_tool_id: -1,
178            current_tool_name_sent: false,
179            streamed_args_for_tool: Vec::new(),
180            tool_call_start_token: "<tool_call>",
181            tool_call_end_token: "</tool_call>",
182            in_tool_call: false,
183            current_function_name: String::new(),
184            current_parameters: serde_json::Map::new(),
185            xml_function_pattern,
186            xml_param_pattern,
187        }
188    }
189
190    /// Parse XML format tool call: <function=name><parameter=key>value</parameter></function>
191    fn parse_xml_format(&self, content: &str) -> ParserResult<Option<ToolCall>> {
192        let function_captures = self
193            .xml_function_pattern
194            .captures(content)
195            .ok_or_else(|| ParserError::ParsingFailed("No function name found".to_string()))?;
196
197        let function_name = function_captures
198            .get(1)
199            .ok_or_else(|| ParserError::ParsingFailed("Function name capture failed".to_string()))?
200            .as_str()
201            .trim()
202            .to_string();
203
204        if function_name.is_empty() {
205            return Ok(None);
206        }
207
208        let mut parameters = serde_json::Map::new();
209
210        for cap in self.xml_param_pattern.captures_iter(content) {
211            if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
212                let key = key_match.as_str().trim().to_string();
213                let value = value_match.as_str();
214                let json_value = safe_val(value);
215                parameters.insert(key, json_value);
216            }
217        }
218
219        let arguments = serde_json::to_string(&parameters)
220            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
221
222        Ok(Some(ToolCall {
223            function: FunctionCall {
224                name: function_name,
225                arguments,
226            },
227        }))
228    }
229
230    /// Parse and stream complete parameters from buffer
231    /// Returns tool call items to emit (similar to Python's _parse_and_stream_parameters)
232    fn parse_and_stream_parameters(&mut self) -> Vec<ToolCallItem> {
233        let mut calls: Vec<ToolCallItem> = vec![];
234
235        // Find all complete parameter patterns in buffer
236        let mut new_params = serde_json::Map::new();
237        for cap in self.xml_param_pattern.captures_iter(&self.buffer) {
238            if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
239                let key = key_match.as_str().trim().to_string();
240                let value = value_match.as_str();
241                let json_value = safe_val(value);
242                new_params.insert(key, json_value);
243            }
244        }
245
246        // Calculate parameter diff and stream updates
247        if new_params != self.current_parameters {
248            let current_args = &mut self.streamed_args_for_tool[self.current_tool_id as usize];
249
250            if self.current_parameters.is_empty() {
251                // First parameter(s) - build JSON fragment (without closing brace)
252                let mut items = Vec::new();
253                for (key, value) in &new_params {
254                    let key_json =
255                        serde_json::to_string(key).unwrap_or_else(|_| format!("\"{key}\""));
256                    let value_json = serde_json::to_string(value).unwrap_or_default();
257                    items.push(format!("{key_json}: {value_json}"));
258                }
259                let json_fragment = format!("{{{}", items.join(", "));
260
261                calls.push(ToolCallItem {
262                    tool_index: self.current_tool_id as usize,
263                    name: None,
264                    parameters: json_fragment.clone(),
265                });
266                *current_args = json_fragment;
267            } else {
268                // Additional parameters - add them incrementally
269                let new_keys: Vec<_> = new_params
270                    .keys()
271                    .filter(|k| !self.current_parameters.contains_key(*k))
272                    .collect();
273
274                if !new_keys.is_empty() {
275                    let mut continuation_parts = Vec::new();
276                    for key in new_keys {
277                        if let Some(value) = new_params.get(key) {
278                            let key_json =
279                                serde_json::to_string(key).unwrap_or_else(|_| format!("\"{key}\""));
280                            let value_json = serde_json::to_string(value).unwrap_or_default();
281                            continuation_parts.push(format!("{key_json}: {value_json}"));
282                        }
283                    }
284
285                    let json_fragment = format!(", {}", continuation_parts.join(", "));
286
287                    calls.push(ToolCallItem {
288                        tool_index: self.current_tool_id as usize,
289                        name: None,
290                        parameters: json_fragment.clone(),
291                    });
292                    current_args.push_str(&json_fragment);
293                }
294            }
295
296            // Update current state
297            self.current_parameters.clone_from(&new_params);
298            if let Some(tool_obj) =
299                self.prev_tool_call_arr[self.current_tool_id as usize].as_object_mut()
300            {
301                tool_obj.insert("arguments".to_string(), Value::Object(new_params));
302            }
303        }
304
305        calls
306    }
307
308    /// Reset streaming state for next tool call
309    fn reset_streaming_state(&mut self) {
310        self.in_tool_call = false;
311        self.current_tool_name_sent = false;
312        self.current_function_name.clear();
313        self.current_parameters.clear();
314    }
315}
316
317impl Default for QwenCoderParser {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
323#[async_trait]
324impl ToolParser for QwenCoderParser {
325    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
326        // Check if text contains Qwen Coder format
327        if !self.has_tool_markers(text) {
328            return Ok((text.to_string(), vec![]));
329        }
330
331        // Find where the first tool call begins
332        // Safe: has_tool_markers() already confirmed the marker exists
333        let idx = text
334            .find(self.tool_call_start_token)
335            .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
336        let normal_text = text[..idx].to_string();
337
338        // Extract tool calls
339        let mut tools = Vec::new();
340        for captures in self.extractor.captures_iter(text) {
341            if let Some(content_str) = captures.get(1) {
342                let content = content_str.as_str().trim();
343
344                match self.parse_xml_format(content) {
345                    Ok(Some(tool)) => tools.push(tool),
346                    Ok(None) => continue,
347                    Err(e) => {
348                        tracing::warn!("Failed to parse XML tool call: {:?}", e);
349                        continue;
350                    }
351                }
352            }
353        }
354
355        // If no tools were successfully parsed despite having markers, return entire text
356        if tools.is_empty() {
357            return Ok((text.to_string(), vec![]));
358        }
359
360        Ok((normal_text, tools))
361    }
362
363    async fn parse_incremental(
364        &mut self,
365        chunk: &str,
366        tools: &[Tool],
367    ) -> ParserResult<StreamingParseResult> {
368        self.buffer.push_str(chunk);
369
370        let mut normal_text = String::new();
371        let mut calls: Vec<ToolCallItem> = vec![];
372
373        // Build tool indices for validation
374        let tool_indices = helpers::get_tool_indices(tools);
375
376        loop {
377            // If we're not in a tool call and don't see a start token, return normal text
378            if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
379                // Check for partial start token
380                if helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
381                    .is_none()
382                {
383                    normal_text.push_str(&self.buffer);
384                    self.buffer.clear();
385                }
386                break;
387            }
388
389            // Look for tool call start
390            if !self.in_tool_call {
391                if let Some(s) = self.buffer.find(self.tool_call_start_token) {
392                    normal_text.push_str(&self.buffer[..s]);
393                    self.buffer = self.buffer[s + self.tool_call_start_token.len()..].to_string();
394                    self.in_tool_call = true;
395                    self.current_tool_name_sent = false;
396                    self.current_function_name.clear();
397                    self.current_parameters.clear();
398                    continue;
399                } else {
400                    break;
401                }
402            }
403
404            // We're in a tool call, try to parse function name if not sent yet
405            if !self.current_tool_name_sent {
406                if let Some(captures) = self.xml_function_pattern.captures(&self.buffer) {
407                    if let Some(name_match) = captures.get(1) {
408                        let function_name = name_match.as_str().trim().to_string();
409
410                        // Validate function name
411                        if tool_indices.contains_key(&function_name) {
412                            self.current_function_name.clone_from(&function_name);
413                            self.current_tool_name_sent = true;
414
415                            // Initialize tool call tracking
416                            if self.current_tool_id == -1 {
417                                self.current_tool_id = 0;
418                            }
419
420                            // Ensure tracking arrays are large enough
421                            helpers::ensure_capacity(
422                                self.current_tool_id,
423                                &mut self.prev_tool_call_arr,
424                                &mut self.streamed_args_for_tool,
425                            );
426
427                            // Store tool call info
428                            self.prev_tool_call_arr[self.current_tool_id as usize] = serde_json::json!({
429                                "name": function_name,
430                                "arguments": {}
431                            });
432
433                            // Send tool name
434                            calls.push(ToolCallItem {
435                                tool_index: self.current_tool_id as usize,
436                                name: Some(function_name),
437                                parameters: String::new(),
438                            });
439
440                            // Remove processed function declaration from buffer
441                            // Safe: captures.get(0) always returns Some (group 0 is the entire match)
442                            self.buffer =
443                                self.buffer[captures.get(0).map_or(0, |m| m.end())..].to_string();
444                            continue;
445                        } else {
446                            // Invalid function name, reset state
447                            tracing::warn!("Invalid function name: {}", function_name);
448                            self.reset_streaming_state();
449                            normal_text.push_str(&self.buffer);
450                            self.buffer.clear();
451                            break;
452                        }
453                    }
454                } else {
455                    // Function name not complete yet, wait for more text
456                    break;
457                }
458            }
459
460            // Parse parameters (only complete ones)
461            if self.current_tool_name_sent {
462                let param_calls = self.parse_and_stream_parameters();
463                calls.extend(param_calls);
464
465                // Check if tool call is complete
466                if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
467                    // Close JSON object if we have parameters
468                    let current_args = &self.streamed_args_for_tool[self.current_tool_id as usize];
469                    if !current_args.is_empty() {
470                        // Count braces to check if JSON is complete
471                        let open_braces = current_args.matches('{').count();
472                        let close_braces = current_args.matches('}').count();
473                        if open_braces > close_braces {
474                            calls.push(ToolCallItem {
475                                tool_index: self.current_tool_id as usize,
476                                name: None,
477                                parameters: "}".to_string(),
478                            });
479                            self.streamed_args_for_tool[self.current_tool_id as usize].push('}');
480                        }
481                    }
482
483                    // Complete the tool call
484                    self.buffer =
485                        self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
486                    self.reset_streaming_state();
487                    self.current_tool_id += 1;
488                    continue;
489                } else {
490                    // Tool call not complete yet, wait for more text
491                    break;
492                }
493            }
494
495            break;
496        }
497
498        Ok(StreamingParseResult { normal_text, calls })
499    }
500
501    fn has_tool_markers(&self, text: &str) -> bool {
502        text.contains(self.tool_call_start_token)
503    }
504
505    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
506        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
507    }
508
509    fn reset(&mut self) {
510        helpers::reset_parser_state(
511            &mut self.buffer,
512            &mut self.prev_tool_call_arr,
513            &mut self.current_tool_id,
514            &mut self.current_tool_name_sent,
515            &mut self.streamed_args_for_tool,
516        );
517        self.reset_streaming_state();
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `html_unescape` maps every input to its expected output.
    fn check_unescape(cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(html_unescape(input), *expected);
        }
    }

    /// Assert that `safe_val` maps every input to its expected JSON value.
    fn check_safe_val(cases: &[(&str, Value)]) {
        for (input, expected) in cases {
            assert_eq!(safe_val(input), *expected);
        }
    }

    #[test]
    fn test_html_unescape_basic() {
        check_unescape(&[
            ("&amp;", "&"),
            ("&lt;", "<"),
            ("&gt;", ">"),
            ("&quot;", "\""),
            ("&apos;", "'"),
        ]);
    }

    #[test]
    fn test_html_unescape_numeric() {
        check_unescape(&[("&#60;", "<"), ("&#x3C;", "<"), ("&#x3c;", "<")]);
    }

    #[test]
    fn test_html_unescape_mixed() {
        check_unescape(&[("Hello &amp; World &lt;tag&gt;", "Hello & World <tag>")]);
    }

    #[test]
    fn test_html_unescape_unknown() {
        check_unescape(&[
            // Unknown entities with semicolon should be preserved as-is
            ("&unknown;", "&unknown;"),
            // Unterminated entities should NOT have semicolon added
            ("&foo bar", "&foo bar"),
            ("&", "&"),
            ("& ", "& "),
        ]);
    }

    #[test]
    fn test_safe_val_json() {
        check_safe_val(&[
            ("42", Value::Number(42.into())),
            ("1.5", serde_json::json!(1.5)),
            ("true", Value::Bool(true)),
            ("false", Value::Bool(false)),
            ("null", Value::Null),
            (r#"{"key": "value"}"#, serde_json::json!({"key": "value"})),
            (r"[1, 2, 3]", serde_json::json!([1, 2, 3])),
        ]);
    }

    #[test]
    fn test_safe_val_python_literals() {
        check_safe_val(&[
            ("True", Value::Bool(true)),
            ("False", Value::Bool(false)),
            ("None", Value::Null),
        ]);
    }

    #[test]
    fn test_safe_val_string_fallback() {
        check_safe_val(&[
            ("hello world", Value::String("hello world".to_string())),
            ("  spaces  ", Value::String("spaces".to_string())),
        ]);
    }

    #[test]
    fn test_safe_val_html_entities() {
        check_safe_val(&[
            ("&lt;div&gt;", Value::String("<div>".to_string())),
            ("Tom &amp; Jerry", Value::String("Tom & Jerry".to_string())),
        ]);
    }
}