1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7 errors::{ParserError, ParserResult},
8 parsers::helpers,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
/// Parser for tool calls written as
/// `<tool_call><function=NAME><parameter=KEY>VALUE</parameter>…</tool_call>`.
///
/// Holds the compiled patterns plus the mutable state needed to parse a
/// response incrementally, chunk by chunk.
pub struct QwenCoderParser {
    /// Matches one complete `<tool_call>…</tool_call>` block; group 1 is its body.
    extractor: Regex,

    /// Streamed text that has been received but not yet consumed.
    buffer: String,

    /// One JSON object per tool call seen so far, shaped as
    /// `{"name": …, "arguments": {…}}`.
    prev_tool_call_arr: Vec<Value>,

    /// Index of the tool call currently being streamed; `-1` until the first
    /// function name has been recognised.
    current_tool_id: i32,

    /// Whether the name of the current tool call has already been emitted.
    current_tool_name_sent: bool,

    /// Argument JSON text emitted so far, indexed by tool call.
    streamed_args_for_tool: Vec<String>,

    /// Literal that opens a tool call (`<tool_call>`).
    tool_call_start_token: &'static str,
    /// Literal that closes a tool call (`</tool_call>`).
    tool_call_end_token: &'static str,

    /// True while the buffer is positioned inside an open `<tool_call>` block.
    in_tool_call: bool,
    /// Name of the function for the tool call currently being streamed.
    current_function_name: String,
    /// Parameters of the current tool call that have already been streamed.
    current_parameters: serde_json::Map<String, Value>,

    /// Matches `<function=NAME>`; group 1 is the name.
    xml_function_pattern: Regex,
    /// Matches `<parameter=KEY>VALUE</parameter>`; groups 1 and 2 are key and value.
    xml_param_pattern: Regex,
}
56
/// Decodes a small set of HTML character references in `s`.
///
/// Recognised references are `amp`, `lt`, `gt`, `quot`, `apos`, `nbsp` and
/// numeric references in decimal (`&#60;`) or hexadecimal (`&#x3c;`, `&#X3C;`)
/// form. The terminating `;` is optional. Anything that does not decode to a
/// valid character is copied to the output unchanged.
fn html_unescape(s: &str) -> String {
    /// Resolves the text between `&` and `;` to the character it denotes.
    fn decode_entity(name: &str) -> Option<char> {
        match name {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            "nbsp" => Some('\u{00A0}'),
            _ => {
                let digits = name.strip_prefix('#')?;
                let hex_digits = digits
                    .strip_prefix('x')
                    .or_else(|| digits.strip_prefix('X'));
                let code_point = match hex_digits {
                    Some(hex) => u32::from_str_radix(hex, 16).ok()?,
                    None => digits.parse::<u32>().ok()?,
                };
                char::from_u32(code_point)
            }
        }
    }

    let mut out = String::with_capacity(s.len());
    let mut rest = s.chars().peekable();

    while let Some(ch) = rest.next() {
        if ch != '&' {
            out.push(ch);
            continue;
        }

        // Gather the candidate entity name: alphanumerics and '#', optionally
        // terminated by ';' (which is consumed but not part of the name).
        let mut name = String::new();
        let mut terminated = false;
        while let Some(&peeked) = rest.peek() {
            if peeked == ';' {
                rest.next();
                terminated = true;
                break;
            }
            if !(peeked.is_alphanumeric() || peeked == '#') {
                break;
            }
            name.push(peeked);
            rest.next();
        }

        match decode_entity(&name) {
            Some(decoded) => out.push(decoded),
            None => {
                // Not a reference we understand: reproduce exactly what was read.
                out.push('&');
                out.push_str(&name);
                if terminated {
                    out.push(';');
                }
            }
        }
    }

    out
}
127
128fn safe_val(raw: &str) -> Value {
134 let unescaped = html_unescape(raw.trim());
135
136 if let Ok(v) = serde_json::from_str::<Value>(&unescaped) {
138 return v;
139 }
140
141 match unescaped.as_str() {
143 "True" => return Value::Bool(true),
144 "False" => return Value::Bool(false),
145 "None" => return Value::Null,
146 _ => {}
147 }
148
149 Value::String(unescaped)
151}
152
153impl QwenCoderParser {
154 pub fn new() -> Self {
156 let pattern = r"(?s)<tool_call>\s*(.*?)\s*</tool_call>";
158 let extractor = Regex::new(pattern).expect("Valid regex pattern");
159
160 let xml_function_pattern =
162 Regex::new(r"<function=([^>]+)>").expect("Valid XML function pattern");
163 let xml_param_pattern = Regex::new(r"(?s)<parameter=([^>]+)>(.*?)</parameter>")
164 .expect("Valid XML parameter pattern");
165
166 Self {
167 extractor,
168 buffer: String::new(),
169 prev_tool_call_arr: Vec::new(),
170 current_tool_id: -1,
171 current_tool_name_sent: false,
172 streamed_args_for_tool: Vec::new(),
173 tool_call_start_token: "<tool_call>",
174 tool_call_end_token: "</tool_call>",
175 in_tool_call: false,
176 current_function_name: String::new(),
177 current_parameters: serde_json::Map::new(),
178 xml_function_pattern,
179 xml_param_pattern,
180 }
181 }
182
183 fn parse_xml_format(&self, content: &str) -> ParserResult<Option<ToolCall>> {
185 let function_captures = self
186 .xml_function_pattern
187 .captures(content)
188 .ok_or_else(|| ParserError::ParsingFailed("No function name found".to_string()))?;
189
190 let function_name = function_captures
191 .get(1)
192 .ok_or_else(|| ParserError::ParsingFailed("Function name capture failed".to_string()))?
193 .as_str()
194 .trim()
195 .to_string();
196
197 if function_name.is_empty() {
198 return Ok(None);
199 }
200
201 let mut parameters = serde_json::Map::new();
202
203 for cap in self.xml_param_pattern.captures_iter(content) {
204 if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
205 let key = key_match.as_str().trim().to_string();
206 let value = value_match.as_str();
207 let json_value = safe_val(value);
208 parameters.insert(key, json_value);
209 }
210 }
211
212 let arguments = serde_json::to_string(¶meters)
213 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
214
215 Ok(Some(ToolCall {
216 function: FunctionCall {
217 name: function_name,
218 arguments,
219 },
220 }))
221 }
222
223 fn parse_and_stream_parameters(&mut self) -> ParserResult<Vec<ToolCallItem>> {
226 let mut calls: Vec<ToolCallItem> = vec![];
227
228 let mut new_params = serde_json::Map::new();
230 for cap in self.xml_param_pattern.captures_iter(&self.buffer) {
231 if let (Some(key_match), Some(value_match)) = (cap.get(1), cap.get(2)) {
232 let key = key_match.as_str().trim().to_string();
233 let value = value_match.as_str();
234 let json_value = safe_val(value);
235 new_params.insert(key, json_value);
236 }
237 }
238
239 if new_params != self.current_parameters {
241 let current_args = &mut self.streamed_args_for_tool[self.current_tool_id as usize];
242
243 if self.current_parameters.is_empty() {
244 let mut items = Vec::new();
246 for (key, value) in &new_params {
247 let key_json =
248 serde_json::to_string(key).unwrap_or_else(|_| format!("\"{}\"", key));
249 let value_json = serde_json::to_string(value).unwrap_or_default();
250 items.push(format!("{}: {}", key_json, value_json));
251 }
252 let json_fragment = format!("{{{}", items.join(", "));
253
254 calls.push(ToolCallItem {
255 tool_index: self.current_tool_id as usize,
256 name: None,
257 parameters: json_fragment.clone(),
258 });
259 *current_args = json_fragment;
260 } else {
261 let new_keys: Vec<_> = new_params
263 .keys()
264 .filter(|k| !self.current_parameters.contains_key(*k))
265 .collect();
266
267 if !new_keys.is_empty() {
268 let mut continuation_parts = Vec::new();
269 for key in new_keys {
270 if let Some(value) = new_params.get(key) {
271 let key_json = serde_json::to_string(key)
272 .unwrap_or_else(|_| format!("\"{}\"", key));
273 let value_json = serde_json::to_string(value).unwrap_or_default();
274 continuation_parts.push(format!("{}: {}", key_json, value_json));
275 }
276 }
277
278 let json_fragment = format!(", {}", continuation_parts.join(", "));
279
280 calls.push(ToolCallItem {
281 tool_index: self.current_tool_id as usize,
282 name: None,
283 parameters: json_fragment.clone(),
284 });
285 current_args.push_str(&json_fragment);
286 }
287 }
288
289 self.current_parameters = new_params.clone();
291 if let Some(tool_obj) =
292 self.prev_tool_call_arr[self.current_tool_id as usize].as_object_mut()
293 {
294 tool_obj.insert("arguments".to_string(), Value::Object(new_params));
295 }
296 }
297
298 Ok(calls)
299 }
300
301 fn reset_streaming_state(&mut self) {
303 self.in_tool_call = false;
304 self.current_tool_name_sent = false;
305 self.current_function_name.clear();
306 self.current_parameters.clear();
307 }
308}
309
310impl Default for QwenCoderParser {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
316#[async_trait]
317impl ToolParser for QwenCoderParser {
318 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
319 if !self.has_tool_markers(text) {
321 return Ok((text.to_string(), vec![]));
322 }
323
324 let idx = text.find(self.tool_call_start_token).unwrap();
326 let normal_text = text[..idx].to_string();
327
328 let mut tools = Vec::new();
330 for captures in self.extractor.captures_iter(text) {
331 if let Some(content_str) = captures.get(1) {
332 let content = content_str.as_str().trim();
333
334 match self.parse_xml_format(content) {
335 Ok(Some(tool)) => tools.push(tool),
336 Ok(None) => continue,
337 Err(e) => {
338 tracing::warn!("Failed to parse XML tool call: {:?}", e);
339 continue;
340 }
341 }
342 }
343 }
344
345 if tools.is_empty() {
347 return Ok((text.to_string(), vec![]));
348 }
349
350 Ok((normal_text, tools))
351 }
352
353 async fn parse_incremental(
354 &mut self,
355 chunk: &str,
356 tools: &[Tool],
357 ) -> ParserResult<StreamingParseResult> {
358 self.buffer.push_str(chunk);
359
360 let mut normal_text = String::new();
361 let mut calls: Vec<ToolCallItem> = vec![];
362
363 let tool_indices = helpers::get_tool_indices(tools);
365
366 loop {
367 if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
369 if helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
371 .is_none()
372 {
373 normal_text.push_str(&self.buffer);
374 self.buffer.clear();
375 }
376 break;
377 }
378
379 if !self.in_tool_call {
381 if let Some(s) = self.buffer.find(self.tool_call_start_token) {
382 normal_text.push_str(&self.buffer[..s]);
383 self.buffer = self.buffer[s + self.tool_call_start_token.len()..].to_string();
384 self.in_tool_call = true;
385 self.current_tool_name_sent = false;
386 self.current_function_name.clear();
387 self.current_parameters.clear();
388 continue;
389 } else {
390 break;
391 }
392 }
393
394 if !self.current_tool_name_sent {
396 if let Some(captures) = self.xml_function_pattern.captures(&self.buffer) {
397 if let Some(name_match) = captures.get(1) {
398 let function_name = name_match.as_str().trim().to_string();
399
400 if tool_indices.contains_key(&function_name) {
402 self.current_function_name = function_name.clone();
403 self.current_tool_name_sent = true;
404
405 if self.current_tool_id == -1 {
407 self.current_tool_id = 0;
408 }
409
410 helpers::ensure_capacity(
412 self.current_tool_id,
413 &mut self.prev_tool_call_arr,
414 &mut self.streamed_args_for_tool,
415 );
416
417 self.prev_tool_call_arr[self.current_tool_id as usize] = serde_json::json!({
419 "name": function_name,
420 "arguments": {}
421 });
422
423 calls.push(ToolCallItem {
425 tool_index: self.current_tool_id as usize,
426 name: Some(function_name),
427 parameters: String::new(),
428 });
429
430 self.buffer = self.buffer[captures.get(0).unwrap().end()..].to_string();
432 continue;
433 } else {
434 tracing::warn!("Invalid function name: {}", function_name);
436 self.reset_streaming_state();
437 normal_text.push_str(&self.buffer);
438 self.buffer.clear();
439 break;
440 }
441 }
442 } else {
443 break;
445 }
446 }
447
448 if self.current_tool_name_sent {
450 let param_calls = self.parse_and_stream_parameters()?;
451 calls.extend(param_calls);
452
453 if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
455 let current_args = &self.streamed_args_for_tool[self.current_tool_id as usize];
457 if !current_args.is_empty() {
458 let open_braces = current_args.matches('{').count();
460 let close_braces = current_args.matches('}').count();
461 if open_braces > close_braces {
462 calls.push(ToolCallItem {
463 tool_index: self.current_tool_id as usize,
464 name: None,
465 parameters: "}".to_string(),
466 });
467 self.streamed_args_for_tool[self.current_tool_id as usize].push('}');
468 }
469 }
470
471 self.buffer =
473 self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
474 self.reset_streaming_state();
475 self.current_tool_id += 1;
476 continue;
477 } else {
478 break;
480 }
481 }
482
483 break;
484 }
485
486 Ok(StreamingParseResult { normal_text, calls })
487 }
488
489 fn has_tool_markers(&self, text: &str) -> bool {
490 text.contains(self.tool_call_start_token)
491 }
492
493 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
494 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
495 }
496
497 fn reset(&mut self) {
498 helpers::reset_parser_state(
499 &mut self.buffer,
500 &mut self.prev_tool_call_arr,
501 &mut self.current_tool_id,
502 &mut self.current_tool_name_sent,
503 &mut self.streamed_args_for_tool,
504 );
505 self.reset_streaming_state();
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_html_unescape_basic() {
        assert_eq!(html_unescape("&amp;"), "&");
        assert_eq!(html_unescape("&lt;"), "<");
        assert_eq!(html_unescape("&gt;"), ">");
        assert_eq!(html_unescape("&quot;"), "\"");
        assert_eq!(html_unescape("&apos;"), "'");
    }

    #[test]
    fn test_html_unescape_numeric() {
        // Decimal, lowercase-hex and uppercase-hex spellings of '<'.
        assert_eq!(html_unescape("&#60;"), "<");
        assert_eq!(html_unescape("&#x3c;"), "<");
        assert_eq!(html_unescape("&#X3C;"), "<");
    }

    #[test]
    fn test_html_unescape_mixed() {
        assert_eq!(
            html_unescape("Hello &amp; World &lt;tag&gt;"),
            "Hello & World <tag>"
        );
    }

    #[test]
    fn test_html_unescape_unknown() {
        // Unknown entity names are preserved verbatim, including the ';'.
        assert_eq!(html_unescape("&unknown;"), "&unknown;");
        // An unterminated unknown name is preserved without adding a ';'.
        assert_eq!(html_unescape("&foo bar"), "&foo bar");
        // A bare ampersand is preserved.
        assert_eq!(html_unescape("&"), "&");
        assert_eq!(html_unescape("& "), "& ");
    }

    #[test]
    fn test_safe_val_json() {
        assert_eq!(safe_val("42"), Value::Number(42.into()));
        assert_eq!(safe_val("1.5"), serde_json::json!(1.5));
        assert_eq!(safe_val("true"), Value::Bool(true));
        assert_eq!(safe_val("false"), Value::Bool(false));
        assert_eq!(safe_val("null"), Value::Null);
        assert_eq!(
            safe_val(r#"{"key": "value"}"#),
            serde_json::json!({"key": "value"})
        );
        assert_eq!(safe_val(r#"[1, 2, 3]"#), serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_safe_val_python_literals() {
        assert_eq!(safe_val("True"), Value::Bool(true));
        assert_eq!(safe_val("False"), Value::Bool(false));
        assert_eq!(safe_val("None"), Value::Null);
    }

    #[test]
    fn test_safe_val_string_fallback() {
        assert_eq!(
            safe_val("hello world"),
            Value::String("hello world".to_string())
        );
        assert_eq!(safe_val(" spaces "), Value::String("spaces".to_string()));
    }

    #[test]
    fn test_safe_val_html_entities() {
        assert_eq!(
            safe_val("&lt;div&gt;"),
            Value::String("<div>".to_string())
        );
        assert_eq!(
            safe_val("Tom &amp; Jerry"),
            Value::String("Tom & Jerry".to_string())
        );
    }
}