1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7 errors::{ParserError, ParserResult},
8 parsers::helpers,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
/// Parser for tool calls wrapped in `<tool_call>…</tool_call>`, whose body names
/// the function with `<function=NAME>` and lists its arguments as
/// `<parameter=KEY>VALUE</parameter>` elements.
///
/// The struct carries both the compiled patterns and the mutable state used
/// while parsing a response incrementally.
pub struct QwenCoderParser {
    /// Captures the body of each complete `<tool_call>…</tool_call>` block.
    extractor: Regex,

    /// Text received by `parse_incremental` that has not been consumed yet.
    buffer: String,

    /// One JSON object per streamed tool call, holding its `name` and the
    /// `arguments` parsed so far.
    prev_tool_call_arr: Vec<Value>,

    /// Index of the tool call currently being streamed; `-1` until the first
    /// function name has been recognized.
    current_tool_id: i32,

    /// Whether the function name of the current tool call was already emitted.
    current_tool_name_sent: bool,

    /// Argument JSON text already emitted for each tool call, indexed like
    /// `prev_tool_call_arr`.
    streamed_args_for_tool: Vec<String>,

    /// Literal marker that opens a tool call (`<tool_call>`).
    tool_call_start_token: &'static str,
    /// Literal marker that closes a tool call (`</tool_call>`).
    tool_call_end_token: &'static str,

    /// Set once the start token has been consumed, cleared when the call ends.
    in_tool_call: bool,
    /// Name of the function whose call is currently being streamed.
    current_function_name: String,
    /// Parameters of the current tool call that have already been streamed.
    current_parameters: serde_json::Map<String, Value>,

    /// Captures the function name from `<function=NAME>`.
    xml_function_pattern: Regex,
    /// Captures key and raw value from `<parameter=KEY>VALUE</parameter>`.
    xml_param_pattern: Regex,
}
56
/// Decodes the HTML character references found in `s`.
///
/// Recognizes the named entities `amp`, `lt`, `gt`, `quot`, `apos` and `nbsp`,
/// plus decimal (`&#60;`) and hexadecimal (`&#x3c;`) numeric references. The
/// terminating `;` is optional. Any `&…` sequence that does not decode to a
/// character is copied to the output unchanged.
fn html_unescape(s: &str) -> String {
    // Maps an entity name (the text between `&` and the optional `;`) to the
    // character it stands for, or `None` when the name is not recognized.
    fn decode(name: &str) -> Option<char> {
        match name {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            "nbsp" => Some('\u{00A0}'),
            _ => {
                let digits = name.strip_prefix('#')?;
                let code_point = match digits.strip_prefix(['x', 'X']) {
                    Some(hex) => u32::from_str_radix(hex, 16),
                    None => digits.parse::<u32>(),
                };
                code_point.ok().and_then(char::from_u32)
            }
        }
    }

    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(amp) = rest.find('&') {
        // Everything before the ampersand is plain text.
        let (literal, tail) = rest.split_at(amp);
        out.push_str(literal);
        let tail = &tail[1..];

        // The entity name is the longest run of alphanumerics and '#'.
        let name_len = tail
            .find(|c: char| !(c.is_alphanumeric() || c == '#'))
            .unwrap_or(tail.len());
        let (name, after_name) = tail.split_at(name_len);
        let terminated = after_name.starts_with(';');

        match decode(name) {
            Some(ch) => out.push(ch),
            None => {
                // Not a known reference: reproduce the consumed text verbatim.
                out.push('&');
                out.push_str(name);
                if terminated {
                    out.push(';');
                }
            }
        }

        rest = if terminated {
            &after_name[1..]
        } else {
            after_name
        };
    }

    out.push_str(rest);
    out
}
130
131fn safe_val(raw: &str) -> Value {
137 let unescaped = html_unescape(raw.trim());
138
139 if let Ok(v) = serde_json::from_str::<Value>(&unescaped) {
141 return v;
142 }
143
144 match unescaped.as_str() {
146 "True" => return Value::Bool(true),
147 "False" => return Value::Bool(false),
148 "None" => return Value::Null,
149 _ => {}
150 }
151
152 Value::String(unescaped)
154}
155
156impl QwenCoderParser {
157 #[expect(
159 clippy::expect_used,
160 reason = "regex patterns are compile-time string literals"
161 )]
162 pub fn new() -> Self {
163 let pattern = r"(?s)<tool_call>\s*(.*?)\s*</tool_call>";
165 let extractor = Regex::new(pattern).expect("Valid regex pattern");
166
167 let xml_function_pattern =
169 Regex::new(r"<function=([^>]+)>").expect("Valid XML function pattern");
170 let xml_param_pattern = Regex::new(r"(?s)<parameter=([^>]+)>(.*?)</parameter>")
171 .expect("Valid XML parameter pattern");
172
173 Self {
174 extractor,
175 buffer: String::new(),
176 prev_tool_call_arr: Vec::new(),
177 current_tool_id: -1,
178 current_tool_name_sent: false,
179 streamed_args_for_tool: Vec::new(),
180 tool_call_start_token: "<tool_call>",
181 tool_call_end_token: "</tool_call>",
182 in_tool_call: false,
183 current_function_name: String::new(),
184 current_parameters: serde_json::Map::new(),
185 xml_function_pattern,
186 xml_param_pattern,
187 }
188 }
189
190 fn parse_xml_format(&self, content: &str) -> ParserResult<Option<ToolCall>> {
192 let function_captures = self
193 .xml_function_pattern
194 .captures(content)
195 .ok_or_else(|| ParserError::ParsingFailed("No function name found".to_string()))?;
196
197 let function_name = function_captures
198 .get(1)
199 .ok_or_else(|| ParserError::ParsingFailed("Function name capture failed".to_string()))?
200 .as_str()
201 .trim()
202 .to_string();
203
204 if function_name.is_empty() {
205 return Ok(None);
206 }
207
208 let mut parameters = serde_json::Map::new();
209
210 for cap in self.xml_param_pattern.captures_iter(content) {
211 if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
212 let key = key_match.as_str().trim().to_string();
213 let value = value_match.as_str();
214 let json_value = safe_val(value);
215 parameters.insert(key, json_value);
216 }
217 }
218
219 let arguments = serde_json::to_string(¶meters)
220 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
221
222 Ok(Some(ToolCall {
223 function: FunctionCall {
224 name: function_name,
225 arguments,
226 },
227 }))
228 }
229
230 fn parse_and_stream_parameters(&mut self) -> Vec<ToolCallItem> {
233 let mut calls: Vec<ToolCallItem> = vec![];
234
235 let mut new_params = serde_json::Map::new();
237 for cap in self.xml_param_pattern.captures_iter(&self.buffer) {
238 if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
239 let key = key_match.as_str().trim().to_string();
240 let value = value_match.as_str();
241 let json_value = safe_val(value);
242 new_params.insert(key, json_value);
243 }
244 }
245
246 if new_params != self.current_parameters {
248 let current_args = &mut self.streamed_args_for_tool[self.current_tool_id as usize];
249
250 if self.current_parameters.is_empty() {
251 let mut items = Vec::new();
253 for (key, value) in &new_params {
254 let key_json =
255 serde_json::to_string(key).unwrap_or_else(|_| format!("\"{key}\""));
256 let value_json = serde_json::to_string(value).unwrap_or_default();
257 items.push(format!("{key_json}: {value_json}"));
258 }
259 let json_fragment = format!("{{{}", items.join(", "));
260
261 calls.push(ToolCallItem {
262 tool_index: self.current_tool_id as usize,
263 name: None,
264 parameters: json_fragment.clone(),
265 });
266 *current_args = json_fragment;
267 } else {
268 let new_keys: Vec<_> = new_params
270 .keys()
271 .filter(|k| !self.current_parameters.contains_key(*k))
272 .collect();
273
274 if !new_keys.is_empty() {
275 let mut continuation_parts = Vec::new();
276 for key in new_keys {
277 if let Some(value) = new_params.get(key) {
278 let key_json =
279 serde_json::to_string(key).unwrap_or_else(|_| format!("\"{key}\""));
280 let value_json = serde_json::to_string(value).unwrap_or_default();
281 continuation_parts.push(format!("{key_json}: {value_json}"));
282 }
283 }
284
285 let json_fragment = format!(", {}", continuation_parts.join(", "));
286
287 calls.push(ToolCallItem {
288 tool_index: self.current_tool_id as usize,
289 name: None,
290 parameters: json_fragment.clone(),
291 });
292 current_args.push_str(&json_fragment);
293 }
294 }
295
296 self.current_parameters.clone_from(&new_params);
298 if let Some(tool_obj) =
299 self.prev_tool_call_arr[self.current_tool_id as usize].as_object_mut()
300 {
301 tool_obj.insert("arguments".to_string(), Value::Object(new_params));
302 }
303 }
304
305 calls
306 }
307
308 fn reset_streaming_state(&mut self) {
310 self.in_tool_call = false;
311 self.current_tool_name_sent = false;
312 self.current_function_name.clear();
313 self.current_parameters.clear();
314 }
315}
316
317impl Default for QwenCoderParser {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323#[async_trait]
324impl ToolParser for QwenCoderParser {
325 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
326 if !self.has_tool_markers(text) {
328 return Ok((text.to_string(), vec![]));
329 }
330
331 let idx = text
334 .find(self.tool_call_start_token)
335 .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
336 let normal_text = text[..idx].to_string();
337
338 let mut tools = Vec::new();
340 for captures in self.extractor.captures_iter(text) {
341 if let Some(content_str) = captures.get(1) {
342 let content = content_str.as_str().trim();
343
344 match self.parse_xml_format(content) {
345 Ok(Some(tool)) => tools.push(tool),
346 Ok(None) => continue,
347 Err(e) => {
348 tracing::warn!("Failed to parse XML tool call: {:?}", e);
349 continue;
350 }
351 }
352 }
353 }
354
355 if tools.is_empty() {
357 return Ok((text.to_string(), vec![]));
358 }
359
360 Ok((normal_text, tools))
361 }
362
363 async fn parse_incremental(
364 &mut self,
365 chunk: &str,
366 tools: &[Tool],
367 ) -> ParserResult<StreamingParseResult> {
368 self.buffer.push_str(chunk);
369
370 let mut normal_text = String::new();
371 let mut calls: Vec<ToolCallItem> = vec![];
372
373 let tool_indices = helpers::get_tool_indices(tools);
375
376 loop {
377 if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
379 if helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
381 .is_none()
382 {
383 normal_text.push_str(&self.buffer);
384 self.buffer.clear();
385 }
386 break;
387 }
388
389 if !self.in_tool_call {
391 if let Some(s) = self.buffer.find(self.tool_call_start_token) {
392 normal_text.push_str(&self.buffer[..s]);
393 self.buffer = self.buffer[s + self.tool_call_start_token.len()..].to_string();
394 self.in_tool_call = true;
395 self.current_tool_name_sent = false;
396 self.current_function_name.clear();
397 self.current_parameters.clear();
398 continue;
399 } else {
400 break;
401 }
402 }
403
404 if !self.current_tool_name_sent {
406 if let Some(captures) = self.xml_function_pattern.captures(&self.buffer) {
407 if let Some(name_match) = captures.get(1) {
408 let function_name = name_match.as_str().trim().to_string();
409
410 if tool_indices.contains_key(&function_name) {
412 self.current_function_name.clone_from(&function_name);
413 self.current_tool_name_sent = true;
414
415 if self.current_tool_id == -1 {
417 self.current_tool_id = 0;
418 }
419
420 helpers::ensure_capacity(
422 self.current_tool_id,
423 &mut self.prev_tool_call_arr,
424 &mut self.streamed_args_for_tool,
425 );
426
427 self.prev_tool_call_arr[self.current_tool_id as usize] = serde_json::json!({
429 "name": function_name,
430 "arguments": {}
431 });
432
433 calls.push(ToolCallItem {
435 tool_index: self.current_tool_id as usize,
436 name: Some(function_name),
437 parameters: String::new(),
438 });
439
440 self.buffer =
443 self.buffer[captures.get(0).map_or(0, |m| m.end())..].to_string();
444 continue;
445 } else {
446 tracing::warn!("Invalid function name: {}", function_name);
448 self.reset_streaming_state();
449 normal_text.push_str(&self.buffer);
450 self.buffer.clear();
451 break;
452 }
453 }
454 } else {
455 break;
457 }
458 }
459
460 if self.current_tool_name_sent {
462 let param_calls = self.parse_and_stream_parameters();
463 calls.extend(param_calls);
464
465 if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
467 let current_args = &self.streamed_args_for_tool[self.current_tool_id as usize];
469 if !current_args.is_empty() {
470 let open_braces = current_args.matches('{').count();
472 let close_braces = current_args.matches('}').count();
473 if open_braces > close_braces {
474 calls.push(ToolCallItem {
475 tool_index: self.current_tool_id as usize,
476 name: None,
477 parameters: "}".to_string(),
478 });
479 self.streamed_args_for_tool[self.current_tool_id as usize].push('}');
480 }
481 }
482
483 self.buffer =
485 self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
486 self.reset_streaming_state();
487 self.current_tool_id += 1;
488 continue;
489 } else {
490 break;
492 }
493 }
494
495 break;
496 }
497
498 Ok(StreamingParseResult { normal_text, calls })
499 }
500
501 fn has_tool_markers(&self, text: &str) -> bool {
502 text.contains(self.tool_call_start_token)
503 }
504
505 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
506 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
507 }
508
509 fn reset(&mut self) {
510 helpers::reset_parser_state(
511 &mut self.buffer,
512 &mut self.prev_tool_call_arr,
513 &mut self.current_tool_id,
514 &mut self.current_tool_name_sent,
515 &mut self.streamed_args_for_tool,
516 );
517 self.reset_streaming_state();
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// The five basic named entities decode to their characters.
    #[test]
    fn test_html_unescape_basic() {
        assert_eq!(html_unescape("&amp;"), "&");
        assert_eq!(html_unescape("&lt;"), "<");
        assert_eq!(html_unescape("&gt;"), ">");
        assert_eq!(html_unescape("&quot;"), "\"");
        assert_eq!(html_unescape("&apos;"), "'");
    }

    /// Decimal and hexadecimal (either case) numeric references decode.
    #[test]
    fn test_html_unescape_numeric() {
        assert_eq!(html_unescape("&#60;"), "<");
        assert_eq!(html_unescape("&#x3c;"), "<");
        assert_eq!(html_unescape("&#x3C;"), "<");
    }

    /// Entities embedded in ordinary text are decoded in place.
    #[test]
    fn test_html_unescape_mixed() {
        assert_eq!(
            html_unescape("Hello &amp; World &lt;tag&gt;"),
            "Hello & World <tag>"
        );
    }

    /// Unknown or incomplete references are passed through unchanged.
    #[test]
    fn test_html_unescape_unknown() {
        assert_eq!(html_unescape("&unknown;"), "&unknown;");
        assert_eq!(html_unescape("&foo bar"), "&foo bar");
        assert_eq!(html_unescape("&"), "&");
        assert_eq!(html_unescape("& "), "& ");
    }

    /// Valid JSON text is parsed into the matching JSON value.
    #[test]
    fn test_safe_val_json() {
        assert_eq!(safe_val("42"), Value::Number(42.into()));
        assert_eq!(safe_val("1.5"), serde_json::json!(1.5));
        assert_eq!(safe_val("true"), Value::Bool(true));
        assert_eq!(safe_val("false"), Value::Bool(false));
        assert_eq!(safe_val("null"), Value::Null);
        assert_eq!(
            safe_val(r#"{"key": "value"}"#),
            serde_json::json!({"key": "value"})
        );
        assert_eq!(safe_val(r"[1, 2, 3]"), serde_json::json!([1, 2, 3]));
    }

    /// Python-style literals map to their JSON counterparts.
    #[test]
    fn test_safe_val_python_literals() {
        assert_eq!(safe_val("True"), Value::Bool(true));
        assert_eq!(safe_val("False"), Value::Bool(false));
        assert_eq!(safe_val("None"), Value::Null);
    }

    /// Anything else becomes a trimmed string.
    #[test]
    fn test_safe_val_string_fallback() {
        assert_eq!(
            safe_val("hello world"),
            Value::String("hello world".to_string())
        );
        assert_eq!(safe_val("  spaces  "), Value::String("spaces".to_string()));
    }

    /// HTML entities are decoded before the value is interpreted.
    #[test]
    fn test_safe_val_html_entities() {
        assert_eq!(
            safe_val("&lt;div&gt;"),
            Value::String("<div>".to_string())
        );
        assert_eq!(
            safe_val("Tom &amp; Jerry"),
            Value::String("Tom & Jerry".to_string())
        );
    }
}