1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9 errors::{ParserError, ParserResult},
10 parsers::helpers,
11 traits::ToolParser,
12 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
15pub struct Step3Parser {
25 tool_call_extractor: Regex,
27 invoke_extractor: Regex,
29 param_extractor: Regex,
31
32 buffer: String,
34
35 bot_token: &'static str,
37 eot_token: &'static str,
38 tool_call_begin: &'static str,
39 tool_call_end: &'static str,
40 tool_sep: &'static str,
41
42 in_tool_block: bool,
44 tool_block_finished: bool,
45 current_function_name: String,
46 current_parameters: serde_json::Map<String, Value>,
47 in_tool_call: bool,
48 function_name_sent: bool,
49
50 prev_tool_call_arr: Vec<Value>,
52 current_tool_id: i32,
53 streamed_args_for_tool: Vec<String>,
54}
55
56impl Step3Parser {
57 #[expect(
59 clippy::expect_used,
60 reason = "regex patterns are compile-time string literals"
61 )]
62 pub fn new() -> Self {
63 let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
65 let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
66
67 let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
69 let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
70
71 let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
73 let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
74
75 Self {
76 tool_call_extractor,
77 invoke_extractor,
78 param_extractor,
79
80 buffer: String::new(),
81
82 bot_token: "<|tool_calls_begin|>",
83 eot_token: "<|tool_calls_end|>",
84 tool_call_begin: "<|tool_call_begin|>",
85 tool_call_end: "<|tool_call_end|>",
86 tool_sep: "<|tool_sep|>",
87
88 in_tool_block: false,
90 tool_block_finished: false,
91 current_function_name: String::new(),
92 current_parameters: serde_json::Map::new(),
93 in_tool_call: false,
94 function_name_sent: false,
95
96 prev_tool_call_arr: Vec::new(),
98 current_tool_id: -1,
99 streamed_args_for_tool: Vec::new(),
100 }
101 }
102
103 fn reset_streaming_state(&mut self) {
105 self.in_tool_call = false;
106 self.function_name_sent = false;
107 self.current_function_name.clear();
108 self.current_parameters.clear();
109 }
110
111 fn parse_partial_tool_call(
113 &mut self,
114 tool_indices: &HashMap<String, usize>,
115 ) -> StreamingParseResult {
116 let mut calls = Vec::new();
117
118 if !self.buffer.contains(self.tool_sep) {
120 return StreamingParseResult {
121 normal_text: String::new(),
122 calls,
123 };
124 }
125
126 let buffer_clone = self.buffer.clone();
128 let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
129 if parts.len() != 2 {
130 return StreamingParseResult {
131 normal_text: String::new(),
132 calls,
133 };
134 }
135
136 let type_part = parts[0].trim();
137 let invoke_part = parts[1];
138
139 if type_part != "function" {
141 self.reset_streaming_state();
143 return StreamingParseResult {
144 normal_text: String::new(),
145 calls,
146 };
147 }
148
149 if !self.function_name_sent {
151 if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
152 let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
153
154 if tool_indices.contains_key(func_name) {
156 self.current_function_name = func_name.to_string();
157 self.function_name_sent = true;
158
159 if self.current_tool_id == -1 {
161 self.current_tool_id = 0;
162 }
163
164 helpers::ensure_capacity(
166 self.current_tool_id,
167 &mut self.prev_tool_call_arr,
168 &mut self.streamed_args_for_tool,
169 );
170
171 let tool_id = self.current_tool_id as usize;
173 self.prev_tool_call_arr[tool_id] = serde_json::json!({
174 "name": func_name,
175 "arguments": {},
176 });
177
178 calls.push(ToolCallItem {
180 tool_index: self.current_tool_id as usize,
181 name: Some(func_name.to_string()),
182 parameters: String::new(),
183 });
184 } else {
185 tracing::debug!("Invalid function name: {}", func_name);
187 self.reset_streaming_state();
188 return StreamingParseResult {
189 normal_text: String::new(),
190 calls,
191 };
192 }
193 } else {
194 return StreamingParseResult {
196 normal_text: String::new(),
197 calls,
198 };
199 }
200 }
201
202 if self.function_name_sent {
204 let mut new_params = serde_json::Map::new();
206 for capture in self.param_extractor.captures_iter(invoke_part) {
207 let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
208 let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
209
210 let param_value =
212 if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
213 json_val
214 } else {
215 if param_value_str == "true" || param_value_str == "True" {
217 Value::Bool(true)
218 } else if param_value_str == "false" || param_value_str == "False" {
219 Value::Bool(false)
220 } else if param_value_str == "null" || param_value_str == "None" {
221 Value::Null
222 } else if let Ok(num) = param_value_str.parse::<i64>() {
223 Value::Number(num.into())
224 } else if let Ok(num) = param_value_str.parse::<f64>() {
225 if let Some(n) = serde_json::Number::from_f64(num) {
226 Value::Number(n)
227 } else {
228 Value::String(param_value_str.to_string())
229 }
230 } else {
231 Value::String(param_value_str.to_string())
232 }
233 };
234
235 new_params.insert(param_name.to_string(), param_value);
236 }
237
238 if new_params != self.current_parameters {
240 let diff = if self.current_parameters.is_empty() {
242 let params_content =
244 serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
245 if params_content.len() > 2 {
246 params_content[..params_content.len() - 1].to_string()
248 } else {
249 "{".to_string()
250 }
251 } else {
252 let old_json = serde_json::to_string(&self.current_parameters)
254 .unwrap_or_else(|_| "{}".to_string());
255 let new_json =
256 serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
257
258 let old_without_brace = &old_json[..old_json.len() - 1];
260 let new_without_brace = &new_json[..new_json.len() - 1];
261
262 new_without_brace
264 .strip_prefix(old_without_brace)
265 .map(|s| s.to_string())
266 .unwrap_or_default()
267 };
268
269 if !diff.is_empty() {
270 calls.push(ToolCallItem {
271 tool_index: self.current_tool_id as usize,
272 name: None,
273 parameters: diff.clone(),
274 });
275 let tool_id = self.current_tool_id as usize;
276 if tool_id < self.streamed_args_for_tool.len() {
277 self.streamed_args_for_tool[tool_id].push_str(&diff);
278 }
279 }
280
281 self.current_parameters.clone_from(&new_params);
283 let tool_id = self.current_tool_id as usize;
284 if tool_id < self.prev_tool_call_arr.len() {
285 if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
286 obj.insert("arguments".to_string(), Value::Object(new_params));
287 }
288 }
289 }
290
291 if self.buffer.contains(self.tool_call_end) {
293 let tool_id = self.current_tool_id as usize;
295 if tool_id < self.streamed_args_for_tool.len()
296 && !self.streamed_args_for_tool[tool_id].is_empty()
297 {
298 calls.push(ToolCallItem {
299 tool_index: self.current_tool_id as usize,
300 name: None,
301 parameters: "}".to_string(),
302 });
303 self.streamed_args_for_tool[tool_id].push('}');
304 }
305
306 if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
308 self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
310 }
311
312 self.reset_streaming_state();
314 self.current_tool_id += 1;
315 }
316 }
317
318 StreamingParseResult {
319 normal_text: String::new(),
320 calls,
321 }
322 }
323
324 fn parse_steptml_parameters(&self, params_text: &str) -> serde_json::Map<String, Value> {
326 let mut parameters = serde_json::Map::new();
327
328 for capture in self.param_extractor.captures_iter(params_text) {
329 let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
330 let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
331
332 let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
334 json_val
335 } else {
336 if param_value_str == "true" || param_value_str == "True" {
338 Value::Bool(true)
339 } else if param_value_str == "false" || param_value_str == "False" {
340 Value::Bool(false)
341 } else if param_value_str == "null" || param_value_str == "None" {
342 Value::Null
343 } else if let Ok(num) = param_value_str.parse::<i64>() {
344 Value::Number(num.into())
345 } else if let Ok(num) = param_value_str.parse::<f64>() {
346 if let Some(n) = serde_json::Number::from_f64(num) {
347 Value::Number(n)
348 } else {
349 Value::String(param_value_str.to_string())
350 }
351 } else {
352 Value::String(param_value_str.to_string())
353 }
354 };
355
356 parameters.insert(param_name.to_string(), param_value);
357 }
358
359 parameters
360 }
361
362 fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
364 if !block.contains("function") || !block.contains("<|tool_sep|>") {
366 return Ok(None);
367 }
368
369 let parts: Vec<&str> = block.split("<|tool_sep|>").collect();
371 if parts.len() != 2 {
372 return Ok(None);
373 }
374
375 if !parts[0].contains("function") {
377 return Ok(None);
378 }
379
380 let invoke_part = parts[1];
381
382 if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
384 let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
385
386 if func_name.is_empty() {
388 return Ok(None);
389 }
390
391 let params_text = captures.get(2).map_or("", |m| m.as_str());
392
393 let parameters = self.parse_steptml_parameters(params_text);
395
396 let arguments_str = serde_json::to_string(¶meters)
397 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
398
399 Ok(Some(ToolCall {
400 function: FunctionCall {
401 name: func_name.to_string(),
402 arguments: arguments_str,
403 },
404 }))
405 } else {
406 Ok(None)
407 }
408 }
409}
410
411impl Default for Step3Parser {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
417#[async_trait]
418impl ToolParser for Step3Parser {
419 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
420 if !self.has_tool_markers(text) {
421 return Ok((text.to_string(), vec![]));
422 }
423
424 let idx = text
427 .find("<|tool_calls_begin|>")
428 .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
429 let normal_text = text[..idx].to_string();
430
431 let mut tools = Vec::new();
433 for mat in self.tool_call_extractor.find_iter(text) {
434 match self.parse_tool_call(mat.as_str()) {
435 Ok(Some(tool)) => tools.push(tool),
436 Ok(None) => continue,
437 Err(e) => {
438 tracing::debug!("Failed to parse tool call: {}", e);
439 continue;
440 }
441 }
442 }
443
444 if tools.is_empty() {
446 return Ok((text.to_string(), vec![]));
447 }
448
449 Ok((normal_text, tools))
450 }
451
452 async fn parse_incremental(
453 &mut self,
454 chunk: &str,
455 tools: &[Tool],
456 ) -> ParserResult<StreamingParseResult> {
457 self.buffer.push_str(chunk);
458
459 let tool_indices = helpers::get_tool_indices(tools);
461
462 if self.tool_block_finished {
464 let normal_text = std::mem::take(&mut self.buffer);
465 return Ok(StreamingParseResult {
466 normal_text,
467 calls: vec![],
468 });
469 }
470
471 if !self.in_tool_block {
473 if self.buffer.contains(self.bot_token) {
474 let idx = self.buffer.find(self.bot_token).ok_or_else(|| {
476 ParserError::ParsingFailed("token not found in buffer".to_string())
477 })?;
478 let normal_text = self.buffer[..idx].to_string();
479 self.buffer = self.buffer[idx + self.bot_token.len()..].to_string();
480 self.in_tool_block = true;
481 return Ok(StreamingParseResult {
482 normal_text,
483 calls: vec![],
484 });
485 } else {
486 if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_some() {
488 return Ok(StreamingParseResult::default()); } else {
490 let normal_text = std::mem::take(&mut self.buffer);
491 return Ok(StreamingParseResult {
492 normal_text,
493 calls: vec![],
494 });
495 }
496 }
497 }
498
499 let mut calls = Vec::new();
501
502 if self.buffer.contains(self.eot_token) {
504 let idx = self.buffer.find(self.eot_token).ok_or_else(|| {
506 ParserError::ParsingFailed("token not found in buffer".to_string())
507 })?;
508
509 if self.in_tool_call {
511 let before_eot = &self.buffer[..idx];
513 if before_eot.contains(self.tool_call_end) {
514 let result = self.parse_partial_tool_call(&tool_indices);
516 calls.extend(result.calls);
517 } else {
518 tracing::warn!("Tool block ended with incomplete tool call");
520 }
521 }
522
523 let remaining = self.buffer[idx + self.eot_token.len()..].to_string();
524 self.buffer.clear();
525 self.tool_block_finished = true;
526
527 self.reset_streaming_state();
529
530 return Ok(StreamingParseResult {
531 normal_text: remaining,
532 calls,
533 });
534 }
535
536 if !self.in_tool_call {
538 if self.buffer.contains(self.tool_call_begin) {
539 let idx = self.buffer.find(self.tool_call_begin).ok_or_else(|| {
541 ParserError::ParsingFailed("token not found in buffer".to_string())
542 })?;
543 self.buffer = self.buffer[idx + self.tool_call_begin.len()..].to_string();
545 self.in_tool_call = true;
546 self.function_name_sent = false;
547 self.current_function_name.clear();
548 self.current_parameters.clear();
549 } else {
551 return Ok(StreamingParseResult::default());
553 }
554 }
555
556 if self.in_tool_call {
558 return Ok(self.parse_partial_tool_call(&tool_indices));
559 }
560
561 Ok(StreamingParseResult::default())
562 }
563
564 fn has_tool_markers(&self, text: &str) -> bool {
565 text.contains(self.bot_token)
566 }
567
568 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
569 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
570 }
571
572 fn reset(&mut self) {
573 self.buffer.clear();
575 self.prev_tool_call_arr.clear();
576 self.current_tool_id = -1;
577 self.streamed_args_for_tool.clear();
578
579 self.in_tool_block = false;
581 self.tool_block_finished = false;
582 self.current_function_name.clear();
583 self.current_parameters.clear();
584 self.in_tool_call = false;
585 self.function_name_sent = false;
586 }
587}