Skip to main content

tool_parser/parsers/
minimax_m2.rs

1use std::{collections::HashMap, fmt::Write as FmtWrite};
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9    errors::{ParserError, ParserResult},
10    parsers::helpers,
11    traits::ToolParser,
12    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
/// MiniMax M2 format parser for tool calls
///
/// Handles the MiniMax M2 specific format:
/// `<minimax:tool_call><invoke name="func"><parameter name="key">value</parameter></invoke></minimax:tool_call>`
///
/// Features:
/// - Namespaced XML tags (`minimax:tool_call`)
/// - Function wrapped in `<invoke name="...">` tags
/// - Parameters as `<parameter name="key">value</parameter>`
/// - Incremental JSON streaming for parameters
///
/// The regex fields serve both the one-shot and the streaming paths; the
/// remaining fields are streaming state that `reset` returns to its initial
/// values.
///
/// Reference: https://huggingface.co/MiniMaxAI/MiniMax-M2?chat_template=default
pub struct MinimaxM2Parser {
    // Regex patterns
    /// Matches one whole `<minimax:tool_call>…</minimax:tool_call>` block.
    tool_call_extractor: Regex,
    /// Captures the `name` attribute (group 1) and the body (group 2) of a
    /// complete `<invoke name="…">…</invoke>` element.
    invoke_extractor: Regex,
    /// Captures the `name` attribute (group 1) and the raw value (group 2) of a
    /// complete `<parameter name="…">…</parameter>` element.
    param_extractor: Regex,

    // Streaming state
    /// Text received by `parse_incremental` that has not been consumed yet.
    buffer: String,
    /// Per tool index, a `{"name": …, "arguments": …}` snapshot of the call as
    /// last seen by the streaming path; consulted by `get_unstreamed_tool_args`.
    prev_tool_call_arr: Vec<Value>,
    /// Index of the tool call currently being streamed; `-1` until the first
    /// call starts.
    current_tool_id: i32,
    /// Per tool index, the argument JSON text that has been emitted so far.
    streamed_args_for_tool: Vec<String>,
    /// Name of the function whose parameters are currently being streamed.
    current_function_name: String,
    /// Parameters already accounted for in the current call.
    current_parameters: HashMap<String, Value>,
    /// Set once a start token has been consumed, cleared when the block ends.
    in_tool_call: bool,
    /// Set once the name item for the current call has been emitted.
    function_name_sent: bool,
    /// Set when `</invoke>` was consumed but the block end token has not
    /// arrived yet.
    waiting_for_tool_call_end: bool,

    // Token configuration
    /// Literal that opens a tool-call block.
    tool_call_start_token: &'static str,
    /// Literal that closes a tool-call block.
    tool_call_end_token: &'static str,
    /// Literal that closes an `<invoke>` element.
    invoke_end_token: &'static str,
}
49
50impl MinimaxM2Parser {
51    /// Parse a value from string with consistent logic
52    #[inline]
53    fn parse_value(text: &str) -> Value {
54        // Try parsing as common literals first
55        match text {
56            "true" | "True" => return Value::Bool(true),
57            "false" | "False" => return Value::Bool(false),
58            "null" | "None" => return Value::Null,
59            _ => {}
60        }
61
62        // Try parsing as number
63        if let Ok(num) = text.parse::<i64>() {
64            return Value::Number(num.into());
65        }
66
67        if let Ok(num) = text.parse::<f64>() {
68            if let Some(n) = serde_json::Number::from_f64(num) {
69                return Value::Number(n);
70            }
71        }
72
73        // Default to string
74        Value::String(text.to_string())
75    }
76
77    /// Create a new MiniMax M2 parser
78    #[expect(
79        clippy::expect_used,
80        reason = "regex patterns are compile-time string literals"
81    )]
82    pub fn new() -> Self {
83        // Use (?s) flag for DOTALL mode to handle newlines
84        let tool_call_pattern = r"(?s)<minimax:tool_call>.*?</minimax:tool_call>";
85        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
86
87        let invoke_pattern = r#"(?s)<invoke\s+name="([^"]+)">(.*?)</invoke>"#;
88        let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
89
90        let param_pattern = r#"(?s)<parameter\s+name="([^"]+)">(.*?)</parameter>"#;
91        let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
92
93        Self {
94            tool_call_extractor,
95            invoke_extractor,
96            param_extractor,
97            buffer: String::new(),
98            prev_tool_call_arr: Vec::new(),
99            current_tool_id: -1,
100            streamed_args_for_tool: Vec::new(),
101            current_function_name: String::new(),
102            current_parameters: HashMap::new(),
103            in_tool_call: false,
104            function_name_sent: false,
105            waiting_for_tool_call_end: false,
106            tool_call_start_token: "<minimax:tool_call>",
107            tool_call_end_token: "</minimax:tool_call>",
108            invoke_end_token: "</invoke>",
109        }
110    }
111
112    /// Parse parameters from parameter tags
113    fn parse_parameters(&self, params_text: &str) -> serde_json::Map<String, Value> {
114        let mut parameters = serde_json::Map::new();
115
116        for capture in self.param_extractor.captures_iter(params_text) {
117            let key = capture.get(1).map_or("", |m| m.as_str()).trim();
118            let value_str = capture.get(2).map_or("", |m| m.as_str());
119
120            // Decode XML entities and parse value
121            let decoded_value = Self::decode_xml_entities(value_str);
122
123            // Note: We keep JSON-like strings as strings (not parsed JSON)
124            // This matches the behavior of other parsers like GLM4 MOE
125            let value = Self::parse_value(&decoded_value);
126
127            parameters.insert(key.to_string(), value);
128        }
129
130        parameters
131    }
132
133    /// Decode common XML entities
134    fn decode_xml_entities(text: &str) -> String {
135        text.replace("&lt;", "<")
136            .replace("&gt;", ">")
137            .replace("&amp;", "&")
138            .replace("&quot;", "\"")
139            .replace("&apos;", "'")
140    }
141
142    /// Parse a single tool call block
143    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
144        if let Some(captures) = self.invoke_extractor.captures(block) {
145            // Get function name from invoke tag attribute
146            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
147
148            // Get parameters text
149            let params_text = captures.get(2).map_or("", |m| m.as_str());
150
151            // Parse parameters
152            let parameters = self.parse_parameters(params_text);
153
154            let arguments_str = serde_json::to_string(&parameters)
155                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
156
157            Ok(Some(ToolCall {
158                function: FunctionCall {
159                    name: func_name.to_string(),
160                    arguments: arguments_str,
161                },
162            }))
163        } else {
164            Ok(None)
165        }
166    }
167
168    /// Parse all tool calls from text and return first valid position
169    fn parse_tool_calls_from_text(&self, text: &str) -> (Vec<ToolCall>, Option<usize>) {
170        let mut tools = Vec::new();
171        let mut first_valid_pos = None;
172
173        for mat in self.tool_call_extractor.find_iter(text) {
174            match self.parse_tool_call(mat.as_str()) {
175                Ok(Some(tool)) => {
176                    if first_valid_pos.is_none() {
177                        first_valid_pos = Some(mat.start());
178                    }
179                    tools.push(tool);
180                }
181                Ok(None) => continue,
182                Err(e) => {
183                    tracing::debug!("Failed to parse tool call: {}", e);
184                    continue;
185                }
186            }
187        }
188
189        (tools, first_valid_pos)
190    }
191
192    /// Parse and stream parameters incrementally
193    fn parse_and_stream_parameters(&mut self, text: &str, _tools: &[Tool]) -> Vec<ToolCallItem> {
194        let mut calls = Vec::new();
195
196        // Find all complete parameter patterns in the buffer
197        let param_matches: Vec<_> = self
198            .param_extractor
199            .captures_iter(text)
200            .map(|cap| {
201                let name = cap.get(1).map_or("", |m| m.as_str()).trim().to_string();
202                let value_str = cap.get(2).map_or("", |m| m.as_str());
203                let decoded = Self::decode_xml_entities(value_str);
204
205                // Try parsing as JSON first (for nested objects/arrays)
206                let value = if decoded.starts_with('{') || decoded.starts_with('[') {
207                    if let Ok(json_val) = serde_json::from_str::<Value>(&decoded) {
208                        json_val
209                    } else {
210                        Self::parse_value(&decoded)
211                    }
212                } else {
213                    Self::parse_value(&decoded)
214                };
215
216                (name, value)
217            })
218            .collect();
219
220        // Build new parameters map
221        let mut new_params = HashMap::new();
222        for (name, value) in param_matches {
223            new_params.insert(name, value);
224        }
225
226        // If we have new parameters that weren't in current_parameters, stream them
227        if !new_params.is_empty() && new_params != self.current_parameters {
228            let tool_id = self.current_tool_id as usize;
229
230            // Ensure we have enough capacity
231            while self.streamed_args_for_tool.len() <= tool_id {
232                self.streamed_args_for_tool.push(String::new());
233            }
234
235            // Build incremental JSON with single allocation
236            if self.current_parameters.is_empty() {
237                // First parameters - start JSON object but don't close it
238                let mut json_fragment = String::with_capacity(256);
239                json_fragment.push('{');
240
241                let mut first = true;
242                for (key, value) in &new_params {
243                    if !first {
244                        json_fragment.push_str(", ");
245                    }
246                    // serde_json::to_string for String/Value is infallible; write! to String is infallible
247                    let key_json = serde_json::to_string(key).unwrap_or_default();
248                    let value_json = serde_json::to_string(value).unwrap_or_default();
249                    let _ = write!(&mut json_fragment, "{key_json}: {value_json}");
250                    first = false;
251                }
252
253                calls.push(ToolCallItem {
254                    tool_index: tool_id,
255                    name: None,
256                    parameters: json_fragment.clone(),
257                });
258
259                self.streamed_args_for_tool[tool_id] = json_fragment;
260            } else {
261                // Additional parameters - add them incrementally
262                let new_keys: Vec<_> = new_params
263                    .keys()
264                    .filter(|k| !self.current_parameters.contains_key(*k))
265                    .collect();
266
267                if !new_keys.is_empty() {
268                    let mut json_fragment = String::with_capacity(128);
269
270                    for key in new_keys {
271                        let value = &new_params[key];
272                        // serde_json::to_string for String/Value is infallible; write! to String is infallible
273                        let key_json = serde_json::to_string(key).unwrap_or_default();
274                        let value_json = serde_json::to_string(value).unwrap_or_default();
275                        let _ = write!(&mut json_fragment, ", {key_json}: {value_json}");
276                    }
277
278                    calls.push(ToolCallItem {
279                        tool_index: tool_id,
280                        name: None,
281                        parameters: json_fragment.clone(),
282                    });
283
284                    self.streamed_args_for_tool[tool_id].push_str(&json_fragment);
285                }
286            }
287
288            // Update current parameters
289            self.current_parameters = new_params;
290
291            // Update prev_tool_call_arr
292            while self.prev_tool_call_arr.len() <= tool_id {
293                self.prev_tool_call_arr.push(Value::Null);
294            }
295            self.prev_tool_call_arr[tool_id] = serde_json::json!({
296                "name": self.current_function_name,
297                "arguments": self.current_parameters,
298            });
299        }
300
301        calls
302    }
303}
304
305impl Default for MinimaxM2Parser {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311#[async_trait]
312impl ToolParser for MinimaxM2Parser {
313    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
314        // Check if text contains MiniMax M2 format
315        if !self.has_tool_markers(text) {
316            return Ok((text.to_string(), vec![]));
317        }
318
319        // Parse all tool calls and get first valid position
320        let (tools, first_valid_tool_pos) = self.parse_tool_calls_from_text(text);
321
322        // If no tools were successfully parsed, return entire text as fallback
323        if tools.is_empty() {
324            return Ok((text.to_string(), vec![]));
325        }
326
327        // Determine what text to return as normal_text
328        let normal_text = if let Some(pos) = first_valid_tool_pos {
329            // Return text up to the first valid tool call
330            text[..pos].to_string()
331        } else {
332            // No valid tool calls found, return entire text
333            text.to_string()
334        };
335
336        Ok((normal_text, tools))
337    }
338
339    async fn parse_incremental(
340        &mut self,
341        chunk: &str,
342        tools: &[Tool],
343    ) -> ParserResult<StreamingParseResult> {
344        self.buffer.push_str(chunk);
345        let mut normal_text = String::new();
346        let mut calls = Vec::new();
347
348        // Build tool indices for validation
349        let tool_indices = helpers::get_tool_indices(tools);
350
351        loop {
352            // If we're waiting for the tool call end tag, check for it first
353            if self.waiting_for_tool_call_end {
354                if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
355                    // Complete tool call found
356                    self.buffer =
357                        self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
358                    self.in_tool_call = false;
359                    self.waiting_for_tool_call_end = false;
360                    self.function_name_sent = false;
361                    self.current_function_name.clear();
362                    self.current_parameters.clear();
363                    self.current_tool_id += 1;
364                    continue;
365                } else {
366                    // End tag not complete yet, wait for more text
367                    break;
368                }
369            }
370
371            // If we're not in a tool call and don't see a start token, return normal text
372            if !self.in_tool_call && !self.buffer.contains(self.tool_call_start_token) {
373                // Check if buffer might contain a partial start token at the end
374                if let Some(partial_len) =
375                    helpers::ends_with_partial_token(&self.buffer, self.tool_call_start_token)
376                {
377                    // Return everything except the potential partial token
378                    let end = self.buffer.len() - partial_len;
379                    normal_text = self.buffer[..end].to_string();
380                    self.buffer = self.buffer[end..].to_string();
381                } else {
382                    // No partial token, return all as normal text
383                    normal_text.clone_from(&self.buffer);
384                    self.buffer.clear();
385                }
386                break;
387            }
388
389            // Look for tool call start
390            if !self.in_tool_call {
391                if let Some(start) = self.buffer.find(self.tool_call_start_token) {
392                    normal_text = self.buffer[..start].to_string();
393                    self.buffer =
394                        self.buffer[start + self.tool_call_start_token.len()..].to_string();
395
396                    self.in_tool_call = true;
397                    self.function_name_sent = false;
398                    self.current_function_name.clear();
399                    self.current_parameters.clear();
400
401                    continue;
402                } else {
403                    // No start token found
404                    break;
405                }
406            }
407
408            // We're in a tool call, try to parse function name if not sent yet
409            if !self.function_name_sent {
410                // Use regex to extract function name from <invoke name="..."> pattern
411                // Check if we have enough text to match the invoke pattern
412                if let Some(captures) = self.invoke_extractor.captures(&self.buffer) {
413                    let function_name = captures
414                        .get(1)
415                        .map_or("", |m| m.as_str())
416                        .trim()
417                        .to_string();
418
419                    // Validate function name
420                    if tool_indices.contains_key(&function_name) {
421                        self.current_function_name.clone_from(&function_name);
422                        self.function_name_sent = true;
423
424                        // Initialize tool call tracking
425                        if self.current_tool_id == -1 {
426                            self.current_tool_id = 0;
427                        }
428
429                        // Ensure tracking arrays are large enough
430                        helpers::ensure_capacity(
431                            self.current_tool_id,
432                            &mut self.prev_tool_call_arr,
433                            &mut self.streamed_args_for_tool,
434                        );
435
436                        // Send tool name with empty parameters
437                        calls.push(ToolCallItem {
438                            tool_index: self.current_tool_id as usize,
439                            name: Some(function_name),
440                            parameters: String::new(),
441                        });
442
443                        // Find the position after the opening invoke tag (after the >)
444                        // We only want to remove up to the opening tag, not the full match
445                        if let Some(pos) = self.buffer.find('>') {
446                            self.buffer = self.buffer[pos + 1..].to_string();
447                        }
448                        continue;
449                    } else {
450                        // Invalid function name, reset state
451                        tracing::debug!("Invalid function name: {}", function_name);
452                        self.in_tool_call = false;
453                        normal_text.push_str(&self.buffer);
454                        self.buffer.clear();
455                        break;
456                    }
457                }
458                // No complete invoke pattern found yet, wait for more text
459                break;
460            }
461
462            // Parse parameters incrementally
463            if self.function_name_sent {
464                // Process parameters and get any calls to emit
465                // Note: We need to be careful here - parse_and_stream_parameters needs
466                // to work with the buffer but we can't pass &self.buffer directly
467                // due to borrow checker. Instead, we'll refactor slightly.
468                // For now, keep the clone but mark it as a TODO for future optimization
469                let buffer_copy = self.buffer.clone(); // TODO: Optimize this
470                let parameter_calls = self.parse_and_stream_parameters(&buffer_copy, tools);
471                calls.extend(parameter_calls);
472
473                // Check if tool call is complete (</invoke> found)
474                if let Some(invoke_end) = self.buffer.find(self.invoke_end_token) {
475                    // Add closing brace to complete the JSON object
476                    let tool_id = self.current_tool_id as usize;
477                    if tool_id < self.streamed_args_for_tool.len() {
478                        let current_streamed = &self.streamed_args_for_tool[tool_id];
479                        if !current_streamed.is_empty() && !current_streamed.ends_with('}') {
480                            // Count opening and closing braces to check if JSON is complete
481                            let open_braces = current_streamed.matches('{').count();
482                            let close_braces = current_streamed.matches('}').count();
483                            if open_braces > close_braces {
484                                calls.push(ToolCallItem {
485                                    tool_index: tool_id,
486                                    name: None,
487                                    parameters: "}".to_string(),
488                                });
489                                self.streamed_args_for_tool[tool_id].push('}');
490                            }
491                        }
492                    }
493
494                    // Move buffer past the </invoke>
495                    self.buffer =
496                        self.buffer[invoke_end + self.invoke_end_token.len()..].to_string();
497
498                    // Check if we have the closing </minimax:tool_call>
499                    if let Some(end_pos) = self.buffer.find(self.tool_call_end_token) {
500                        // Complete tool call found
501                        self.buffer =
502                            self.buffer[end_pos + self.tool_call_end_token.len()..].to_string();
503                        self.in_tool_call = false;
504                        self.function_name_sent = false;
505                        self.current_function_name.clear();
506                        self.current_parameters.clear();
507                        self.current_tool_id += 1;
508                        continue;
509                    } else {
510                        // End tag not complete yet, mark that we're waiting for it
511                        self.waiting_for_tool_call_end = true;
512                        break;
513                    }
514                }
515                // Tool call not complete yet, wait for more text
516                break;
517            }
518        }
519
520        Ok(StreamingParseResult { normal_text, calls })
521    }
522
523    fn has_tool_markers(&self, text: &str) -> bool {
524        text.contains(self.tool_call_start_token)
525    }
526
527    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
528        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
529    }
530
531    fn reset(&mut self) {
532        self.buffer.clear();
533        self.prev_tool_call_arr.clear();
534        self.current_tool_id = -1;
535        self.streamed_args_for_tool.clear();
536        self.current_function_name.clear();
537        self.current_parameters.clear();
538        self.in_tool_call = false;
539        self.function_name_sent = false;
540        self.waiting_for_tool_call_end = false;
541    }
542}