Skip to main content

tool_parser/parsers/
minimax_m2.rs

1use std::{collections::HashMap, fmt::Write as FmtWrite};
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9    errors::{ParserError, ParserResult},
10    parsers::helpers,
11    traits::ToolParser,
12    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
15/// MiniMax M2 format parser for tool calls
16///
17/// Handles the MiniMax M2 specific format:
18/// `<minimax:tool_call><invoke name="func"><parameter name="key">value</parameter></invoke></minimax:tool_call>`
19///
20/// Features:
21/// - Namespaced XML tags (`minimax:tool_call`)
22/// - Function wrapped in `<invoke name="...">` tags
23/// - Parameters as `<parameter name="key">value</parameter>`
24/// - Incremental JSON streaming for parameters
25///
26/// Reference: https://huggingface.co/MiniMaxAI/MiniMax-M2?chat_template=default
27pub struct MinimaxM2Parser {
28    // Regex patterns
29    tool_call_extractor: Regex,
30    invoke_extractor: Regex,
31    param_extractor: Regex,
32
33    // Streaming state
34    buffer: String,
35    prev_tool_call_arr: Vec<Value>,
36    current_tool_id: i32,
37    streamed_args_for_tool: Vec<String>,
38    current_function_name: String,
39    current_parameters: HashMap<String, Value>,
40    in_tool_call: bool,
41    function_name_sent: bool,
42    waiting_for_tool_call_end: bool,
43
44    // Token configuration
45    tool_call_start_token: &'static str,
46    tool_call_end_token: &'static str,
47    invoke_end_token: &'static str,
48}
49
50impl MinimaxM2Parser {
51    /// Parse a value from string with consistent logic
52    #[inline]
53    fn parse_value(text: &str) -> Value {
54        // Try parsing as common literals first
55        match text {
56            "true" | "True" => return Value::Bool(true),
57            "false" | "False" => return Value::Bool(false),
58            "null" | "None" => return Value::Null,
59            _ => {}
60        }
61
62        // Try parsing as number
63        if let Ok(num) = text.parse::<i64>() {
64            return Value::Number(num.into());
65        }
66
67        if let Ok(num) = text.parse::<f64>() {
68            if let Some(n) = serde_json::Number::from_f64(num) {
69                return Value::Number(n);
70            }
71        }
72
73        // Default to string
74        Value::String(text.to_string())
75    }
76
77    /// Create a new MiniMax M2 parser
78    pub fn new() -> Self {
79        // Use (?s) flag for DOTALL mode to handle newlines
80        let tool_call_pattern = r"(?s)<minimax:tool_call>.*?</minimax:tool_call>";
81        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
82
83        let invoke_pattern = r#"(?s)<invoke\s+name="([^"]+)">(.*?)</invoke>"#;
84        let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
85
86        let param_pattern = r#"(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>"#;
87        let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
88
89        Self {
90            tool_call_extractor,
91            invoke_extractor,
92            param_extractor,
93            buffer: String::new(),
94            prev_tool_call_arr: Vec::new(),
95            current_tool_id: -1,
96            streamed_args_for_tool: Vec::new(),
97            current_function_name: String::new(),
98            current_parameters: HashMap::new(),
99            in_tool_call: false,
100            function_name_sent: false,
101            waiting_for_tool_call_end: false,
102            tool_call_start_token: "<minimax:tool_call>",
103            tool_call_end_token: "</minimax:tool_call>",
104            invoke_end_token: "</invoke>",
105        }
106    }
107
108    /// Parse parameters from parameter tags
109    fn parse_parameters(&self, params_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
110        let mut parameters = serde_json::Map::new();
111
112        for capture in self.param_extractor.captures_iter(params_text) {
113            let key = capture.get(1).map_or("", |m| m.as_str()).trim();
114            let value_str = capture.get(2).map_or("", |m| m.as_str());
115
116            // Decode XML entities and parse value
117            let decoded_value = self.decode_xml_entities(value_str);
118
119            // Note: We keep JSON-like strings as strings (not parsed JSON)
120            // This matches the behavior of other parsers like GLM4 MOE
121            let value = Self::parse_value(&decoded_value);
122
123            parameters.insert(key.to_string(), value);
124        }
125
126        Ok(parameters)
127    }
128
129    /// Decode common XML entities
130    fn decode_xml_entities(&self, text: &str) -> String {
131        text.replace("&lt;", "<")
132            .replace("&gt;", ">")
133            .replace("&amp;", "&")
134            .replace("&quot;", "\"")
135            .replace("&apos;", "'")
136    }
137
138    /// Parse a single tool call block
139    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
140        if let Some(captures) = self.invoke_extractor.captures(block) {
141            // Get function name from invoke tag attribute
142            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
143
144            // Get parameters text
145            let params_text = captures.get(2).map_or("", |m| m.as_str());
146
147            // Parse parameters
148            let parameters = self.parse_parameters(params_text)?;
149
150            let arguments_str = serde_json::to_string(&parameters)
151                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
152
153            Ok(Some(ToolCall {
154                function: FunctionCall {
155                    name: func_name.to_string(),
156                    arguments: arguments_str,
157                },
158            }))
159        } else {
160            Ok(None)
161        }
162    }
163
164    /// Parse all tool calls from text and return first valid position
165    fn parse_tool_calls_from_text(
166        &self,
167        text: &str,
168    ) -> ParserResult<(Vec<ToolCall>, Option<usize>)> {
169        let mut tools = Vec::new();
170        let mut first_valid_pos = None;
171
172        for mat in self.tool_call_extractor.find_iter(text) {
173            match self.parse_tool_call(mat.as_str()) {
174                Ok(Some(tool)) => {
175                    if first_valid_pos.is_none() {
176                        first_valid_pos = Some(mat.start());
177                    }
178                    tools.push(tool);
179                }
180                Ok(None) => continue,
181                Err(e) => {
182                    tracing::debug!("Failed to parse tool call: {}", e);
183                    continue;
184                }
185            }
186        }
187
188        Ok((tools, first_valid_pos))
189    }
190
191    /// Parse and stream parameters incrementally
192    fn parse_and_stream_parameters(&mut self, text: &str, _tools: &[Tool]) -> Vec<ToolCallItem> {
193        let mut calls = Vec::new();
194
195        // Find all complete parameter patterns in the buffer
196        let param_matches: Vec<_> = self
197            .param_extractor
198            .captures_iter(text)
199            .map(|cap| {
200                let name = cap.get(1).map_or("", |m| m.as_str()).trim().to_string();
201                let value_str = cap.get(2).map_or("", |m| m.as_str());
202                let decoded = self.decode_xml_entities(value_str);
203
204                // Try parsing as JSON first (for nested objects/arrays)
205                let value = if decoded.starts_with('{') || decoded.starts_with('[') {
206                    if let Ok(json_val) = serde_json::from_str::<Value>(&decoded) {
207                        json_val
208                    } else {
209                        Self::parse_value(&decoded)
210                    }
211                } else {
212                    Self::parse_value(&decoded)
213                };
214
215                (name, value)
216            })
217            .collect();
218
219        // Build new parameters map
220        let mut new_params = HashMap::new();
221        for (name, value) in param_matches {
222            new_params.insert(name, value);
223        }
224
225        // If we have new parameters that weren't in current_parameters, stream them
226        if !new_params.is_empty() && new_params != self.current_parameters {
227            let tool_id = self.current_tool_id as usize;
228
229            // Ensure we have enough capacity
230            while self.streamed_args_for_tool.len() <= tool_id {
231                self.streamed_args_for_tool.push(String::new());
232            }
233
234            // Build incremental JSON with single allocation
235            if self.current_parameters.is_empty() {
236                // First parameters - start JSON object but don't close it
237                let mut json_fragment = String::with_capacity(256);
238                json_fragment.push('{');
239
240                let mut first = true;
241                for (key, value) in &new_params {
242                    if !first {
243                        json_fragment.push_str(", ");
244                    }
245                    write!(
246                        &mut json_fragment,
247                        "{}: {}",
248                        serde_json::to_string(key).unwrap(),
249                        serde_json::to_string(value).unwrap()
250                    )
251                    .unwrap();
252                    first = false;
253                }
254
255                calls.push(ToolCallItem {
256                    tool_index: tool_id,
257                    name: None,
258                    parameters: json_fragment.clone(),
259                });
260
261                self.streamed_args_for_tool[tool_id] = json_fragment;
262            } else {
263                // Additional parameters - add them incrementally
264                let new_keys: Vec<_> = new_params
265                    .keys()
266                    .filter(|k| !self.current_parameters.contains_key(*k))
267                    .collect();
268
269                if !new_keys.is_empty() {
270                    let mut json_fragment = String::with_capacity(128);
271
272                    for key in new_keys {
273                        let value = &new_params[key];
274                        write!(
275                            &mut json_fragment,
276                            ", {}: {}",
277                            serde_json::to_string(key).unwrap(),
278                            serde_json::to_string(value).unwrap()
279                        )
280                        .unwrap();
281                    }
282
283                    calls.push(ToolCallItem {
284                        tool_index: tool_id,
285                        name: None,
286                        parameters: json_fragment.clone(),
287                    });
288
289                    self.streamed_args_for_tool[tool_id].push_str(&json_fragment);
290                }
291            }
292
293            // Update current parameters
294            self.current_parameters = new_params;
295
296            // Update prev_tool_call_arr
297            while self.prev_tool_call_arr.len() <= tool_id {
298                self.prev_tool_call_arr.push(Value::Null);
299            }
300            self.prev_tool_call_arr[tool_id] = serde_json::json!({
301                "name": self.current_function_name,
302                "arguments": self.current_parameters,
303            });
304        }
305
306        calls
307    }
308}
309
310impl Default for MinimaxM2Parser {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
316#[async_trait]
317impl ToolParser for MinimaxM2Parser {
318    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
319        // Check if text contains MiniMax M2 format
320        if !self.has_tool_markers(text) {
321            return Ok((text.to_string(), vec![]));
322        }
323
324        // Parse all tool calls and get first valid position
325        let (tools, first_valid_tool_pos) = self.parse_tool_calls_from_text(text)?;
326
327        // If no tools were successfully parsed, return entire text as fallback
328        if tools.is_empty() {
329            return Ok((text.to_string(), vec![]));
330        }
331
332        // Determine what text to return as normal_text
333        let normal_text = if let Some(pos) = first_valid_tool_pos {
334            // Return text up to the first valid tool call
335            text[..pos].to_string()
336        } else {
337            // No valid tool calls found, return entire text
338            text.to_string()
339        };
340
341        Ok((normal_text, tools))
342    }
343
344    async fn parse_incremental(
345        &mut self,
346        chunk: &str,
347        tools: &[Tool],
348    ) -> ParserResult<StreamingParseResult> {
349        self.buffer.push_str(chunk);
350        let mut normal_text = String::new();
351        let mut calls = Vec::new();
352
353        // Build tool indices for validation
354        let tool_indices = helpers::get_tool_indices(tools);
355
356        loop {
357            // If we're waiting for the tool call end tag, check for it first
358            if self.waiting_for_tool_call_end {
359                if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
360                    // Complete tool call found
361                    self.buffer =
362                        self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
363                    self.in_tool_call = false;
364                    self.waiting_for_tool_call_end = false;
365                    self.function_name_sent = false;
366                    self.current_function_name.clear();
367                    self.current_parameters.clear();
368                    self.current_tool_id += 1;
369                    continue;
370                } else {
371                    // End tag not complete yet, wait for more text
372                    break;
373                }
374            }
375
376            // If we're not in a tool call and don't see a start token, return normal text
377            if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
378                // Check if buffer might contain a partial start token at the end
379                if let Some(partial_len) =
380                    helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
381                {
382                    // Return everything except the potential partial token
383                    let end = self.buffer.len() - partial_len;
384                    normal_text = self.buffer[..end].to_string();
385                    self.buffer = self.buffer[end..].to_string();
386                } else {
387                    // No partial token, return all as normal text
388                    normal_text = self.buffer.clone();
389                    self.buffer.clear();
390                }
391                break;
392            }
393
394            // Look for tool call start
395            if !self.in_tool_call {
396                if let Some(start) = self.buffer.find(self.tool_call_start_token) {
397                    normal_text = self.buffer[..start].to_string();
398                    self.buffer =
399                        self.buffer[start + self.tool_call_start_token.len()..].to_string();
400
401                    self.in_tool_call = true;
402                    self.function_name_sent = false;
403                    self.current_function_name.clear();
404                    self.current_parameters.clear();
405
406                    continue;
407                } else {
408                    // No start token found
409                    break;
410                }
411            }
412
413            // We're in a tool call, try to parse function name if not sent yet
414            if !self.function_name_sent {
415                // Use regex to extract function name from <invoke name="..."> pattern
416                // Check if we have enough text to match the invoke pattern
417                if let Some(captures) = self.invoke_extractor.captures(&self.buffer) {
418                    let function_name = captures
419                        .get(1)
420                        .map_or("", |m| m.as_str())
421                        .trim()
422                        .to_string();
423
424                    // Validate function name
425                    if tool_indices.contains_key(&function_name) {
426                        self.current_function_name = function_name.clone();
427                        self.function_name_sent = true;
428
429                        // Initialize tool call tracking
430                        if self.current_tool_id == -1 {
431                            self.current_tool_id = 0;
432                        }
433
434                        // Ensure tracking arrays are large enough
435                        helpers::ensure_capacity(
436                            self.current_tool_id,
437                            &mut self.prev_tool_call_arr,
438                            &mut self.streamed_args_for_tool,
439                        );
440
441                        // Send tool name with empty parameters
442                        calls.push(ToolCallItem {
443                            tool_index: self.current_tool_id as usize,
444                            name: Some(function_name),
445                            parameters: String::new(),
446                        });
447
448                        // Find the position after the opening invoke tag (after the >)
449                        // We only want to remove up to the opening tag, not the full match
450                        if let Some(pos) = self.buffer.find('>') {
451                            self.buffer = self.buffer[pos + 1..].to_string();
452                        }
453                        continue;
454                    } else {
455                        // Invalid function name, reset state
456                        tracing::debug!("Invalid function name: {}", function_name);
457                        self.in_tool_call = false;
458                        normal_text.push_str(&self.buffer);
459                        self.buffer.clear();
460                        break;
461                    }
462                }
463                // No complete invoke pattern found yet, wait for more text
464                break;
465            }
466
467            // Parse parameters incrementally
468            if self.function_name_sent {
469                // Process parameters and get any calls to emit
470                // Note: We need to be careful here - parse_and_stream_parameters needs
471                // to work with the buffer but we can't pass &self.buffer directly
472                // due to borrow checker. Instead, we'll refactor slightly.
473                // For now, keep the clone but mark it as a TODO for future optimization
474                let buffer_copy = self.buffer.clone(); // TODO: Optimize this
475                let parameter_calls = self.parse_and_stream_parameters(&buffer_copy, tools);
476                calls.extend(parameter_calls);
477
478                // Check if tool call is complete (</invoke> found)
479                if let Some(invoke_end) = self.buffer.find(self.invoke_end_token) {
480                    // Add closing brace to complete the JSON object
481                    let tool_id = self.current_tool_id as usize;
482                    if tool_id < self.streamed_args_for_tool.len() {
483                        let current_streamed = &self.streamed_args_for_tool[tool_id];
484                        if !current_streamed.is_empty() && !current_streamed.ends_with('}') {
485                            // Count opening and closing braces to check if JSON is complete
486                            let open_braces = current_streamed.matches('{').count();
487                            let close_braces = current_streamed.matches('}').count();
488                            if open_braces > close_braces {
489                                calls.push(ToolCallItem {
490                                    tool_index: tool_id,
491                                    name: None,
492                                    parameters: "}".to_string(),
493                                });
494                                self.streamed_args_for_tool[tool_id].push('}');
495                            }
496                        }
497                    }
498
499                    // Move buffer past the </invoke>
500                    self.buffer =
501                        self.buffer[invoke_end + self.invoke_end_token.len()..].to_string();
502
503                    // Check if we have the closing </minimax:tool_call>
504                    if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
505                        // Complete tool call found
506                        self.buffer =
507                            self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
508                        self.in_tool_call = false;
509                        self.function_name_sent = false;
510                        self.current_function_name.clear();
511                        self.current_parameters.clear();
512                        self.current_tool_id += 1;
513                        continue;
514                    } else {
515                        // End tag not complete yet, mark that we're waiting for it
516                        self.waiting_for_tool_call_end = true;
517                        break;
518                    }
519                }
520                // Tool call not complete yet, wait for more text
521                break;
522            }
523        }
524
525        Ok(StreamingParseResult { normal_text, calls })
526    }
527
528    fn has_tool_markers(&self, text: &str) -> bool {
529        text.contains(self.tool_call_start_token)
530    }
531
532    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
533        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
534    }
535
536    fn reset(&mut self) {
537        self.buffer.clear();
538        self.prev_tool_call_arr.clear();
539        self.current_tool_id = -1;
540        self.streamed_args_for_tool.clear();
541        self.current_function_name.clear();
542        self.current_parameters.clear();
543        self.in_tool_call = false;
544        self.function_name_sent = false;
545        self.waiting_for_tool_call_end = false;
546    }
547}