Skip to main content

tool_parser/parsers/
glm4_moe.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// GLM-4 MoE format parser for tool calls
14///
15/// Handles both GLM-4 MoE and GLM-4.7 MoE formats:
16/// - GLM-4: `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
17/// - GLM-4.7: `<tool_call>{name}<arg_key>{key}</arg_key><arg_value>{value}</arg_value></tool_call>`
18///
19/// Features:
20/// - XML-style tags for tool calls
21/// - Key-value pairs for arguments
22/// - Support for multiple sequential tool calls
23pub struct Glm4MoeParser {
24    /// Regex for extracting complete tool calls
25    tool_call_extractor: Regex,
26    /// Regex for extracting function details
27    func_detail_extractor: Regex,
28    /// Regex for extracting argument key-value pairs
29    arg_extractor: Regex,
30
31    /// Buffer for accumulating incomplete patterns across chunks
32    buffer: String,
33
34    /// Stores complete tool call info (name and arguments) for each tool being parsed
35    prev_tool_call_arr: Vec<Value>,
36
37    /// Index of currently streaming tool call (-1 means no active tool)
38    current_tool_id: i32,
39
40    /// Tracks raw JSON string content streamed to client for each tool's arguments
41    streamed_args_for_tool: Vec<String>,
42
43    /// Token configuration
44    bot_token: &'static str,
45    eot_token: &'static str,
46}
47
48impl Glm4MoeParser {
49    /// Create a new generic GLM MoE parser with a custom func_detail_extractor pattern
50    ///
51    /// # Arguments
52    /// - `func_detail_pattern`: Regex pattern for extracting function name and arguments
53    ///   - For GLM-4: `r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>"`
54    ///   - For GLM-4.7: `r"(?s)<tool_call>\s*([^<\s]+)\s*(.*?)</tool_call>"`
55    #[expect(
56        clippy::expect_used,
57        reason = "regex patterns are compile-time string literals"
58    )]
59    pub(crate) fn new(func_detail_pattern: &str) -> Self {
60        // Use (?s) flag for DOTALL mode to handle newlines
61        let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
62        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
63
64        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
65
66        let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
67        let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
68
69        Self {
70            tool_call_extractor,
71            func_detail_extractor,
72            arg_extractor,
73            buffer: String::new(),
74            prev_tool_call_arr: Vec::new(),
75            current_tool_id: -1,
76            streamed_args_for_tool: Vec::new(),
77            bot_token: "<tool_call>",
78            eot_token: "</tool_call>",
79        }
80    }
81
82    /// Create a new GLM-4.5/4.6 MoE parser (with newline-based format)
83    pub fn glm45() -> Self {
84        Self::new(r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>")
85    }
86
87    /// Create a new GLM-4.7 MoE parser (with whitespace-based format)
88    pub fn glm47() -> Self {
89        Self::new(r"(?s)<tool_call>\s*([^<\s]+)\s*(.*?)</tool_call>")
90    }
91
92    /// Parse arguments from key-value pairs
93    fn parse_arguments(&self, args_text: &str) -> serde_json::Map<String, Value> {
94        let mut arguments = serde_json::Map::new();
95
96        for capture in self.arg_extractor.captures_iter(args_text) {
97            let key = capture.get(1).map_or("", |m| m.as_str()).trim();
98            let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
99
100            // Try to parse the value as JSON first, fallback to string
101            let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
102                json_val
103            } else {
104                // Try parsing as Python literal (similar to Python's ast.literal_eval)
105                if value_str == "true" || value_str == "True" {
106                    Value::Bool(true)
107                } else if value_str == "false" || value_str == "False" {
108                    Value::Bool(false)
109                } else if value_str == "null" || value_str == "None" {
110                    Value::Null
111                } else if let Ok(num) = value_str.parse::<i64>() {
112                    Value::Number(num.into())
113                } else if let Ok(num) = value_str.parse::<f64>() {
114                    if let Some(n) = serde_json::Number::from_f64(num) {
115                        Value::Number(n)
116                    } else {
117                        Value::String(value_str.to_string())
118                    }
119                } else {
120                    Value::String(value_str.to_string())
121                }
122            };
123
124            arguments.insert(key.to_string(), value);
125        }
126
127        arguments
128    }
129
130    /// Parse a single tool call block
131    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
132        if let Some(captures) = self.func_detail_extractor.captures(block) {
133            // Get function name
134            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
135
136            // Get arguments text
137            let args_text = captures.get(2).map_or("", |m| m.as_str());
138
139            // Parse arguments
140            let arguments = self.parse_arguments(args_text);
141
142            let arguments_str = serde_json::to_string(&arguments)
143                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
144
145            Ok(Some(ToolCall {
146                function: FunctionCall {
147                    name: func_name.to_string(),
148                    arguments: arguments_str,
149                },
150            }))
151        } else {
152            Ok(None)
153        }
154    }
155
156    /// Parse all tool calls from text (shared logic for complete and incremental parsing)
157    fn parse_tool_calls_from_text(&self, text: &str) -> Vec<ToolCall> {
158        let mut tools = Vec::new();
159
160        for mat in self.tool_call_extractor.find_iter(text) {
161            match self.parse_tool_call(mat.as_str()) {
162                Ok(Some(tool)) => tools.push(tool),
163                Ok(None) => continue,
164                Err(e) => {
165                    tracing::debug!("Failed to parse tool call: {}", e);
166                    continue;
167                }
168            }
169        }
170
171        tools
172    }
173}
174
175impl Default for Glm4MoeParser {
176    fn default() -> Self {
177        Self::glm45()
178    }
179}
180
181#[async_trait]
182impl ToolParser for Glm4MoeParser {
183    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
184        // Check if text contains GLM-4 MoE format
185        if !self.has_tool_markers(text) {
186            return Ok((text.to_string(), vec![]));
187        }
188
189        // Find where tool calls begin
190        // Safe: has_tool_markers() already confirmed the marker exists
191        let idx = text
192            .find("<tool_call>")
193            .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
194        let normal_text = text[..idx].to_string();
195
196        // Parse all tool calls using shared helper
197        let tools = self.parse_tool_calls_from_text(text);
198
199        // If no tools were successfully parsed despite having markers, return entire text as fallback
200        if tools.is_empty() {
201            return Ok((text.to_string(), vec![]));
202        }
203
204        Ok((normal_text, tools))
205    }
206
207    async fn parse_incremental(
208        &mut self,
209        chunk: &str,
210        tools: &[Tool],
211    ) -> ParserResult<StreamingParseResult> {
212        // Python logic: Wait for complete tool call, then parse it all at once
213        self.buffer.push_str(chunk);
214        let current_text = &self.buffer.clone();
215
216        // Check if we have bot_token
217        let start = current_text.find(self.bot_token);
218        if start.is_none() {
219            self.buffer.clear();
220            // If we're in the middle of streaming (current_tool_id > 0), don't return text
221            let normal_text = if self.current_tool_id > 0 {
222                String::new()
223            } else {
224                current_text.clone()
225            };
226            return Ok(StreamingParseResult {
227                normal_text,
228                calls: vec![],
229            });
230        }
231
232        // Check if we have eot_token (end of tool call)
233        let end = current_text.find(self.eot_token);
234        if let Some(end_pos) = end {
235            // We have a complete tool call!
236
237            // Initialize state if this is the first tool call
238            if self.current_tool_id == -1 {
239                self.current_tool_id = 0;
240                self.prev_tool_call_arr = Vec::new();
241                self.streamed_args_for_tool = vec![String::new()];
242            }
243
244            // Ensure we have enough entries in our tracking arrays
245            helpers::ensure_capacity(
246                self.current_tool_id,
247                &mut self.prev_tool_call_arr,
248                &mut self.streamed_args_for_tool,
249            );
250
251            // Parse the complete block using shared helper
252            let block_end = end_pos + self.eot_token.len();
253            let parsed_tools = self.parse_tool_calls_from_text(&current_text[..block_end]);
254
255            // Extract normal text before tool calls
256            let idx = current_text.find(self.bot_token);
257            let normal_text = if let Some(pos) = idx {
258                current_text[..pos].trim().to_string()
259            } else {
260                String::new()
261            };
262
263            // Build tool indices for validation
264            let tool_indices = helpers::get_tool_indices(tools);
265
266            let mut calls = Vec::new();
267
268            if !parsed_tools.is_empty() {
269                // Take the first tool and convert to ToolCallItem
270                let tool_call = &parsed_tools[0];
271                let tool_id = self.current_tool_id as usize;
272
273                // Validate tool name
274                if !tool_indices.contains_key(&tool_call.function.name) {
275                    // Invalid tool name - skip this tool, preserve indexing for next tool
276                    tracing::debug!("Invalid tool name '{}' - skipping", tool_call.function.name);
277                    helpers::reset_current_tool_state(
278                        &mut self.buffer,
279                        &mut false, // glm45_moe/glm47_moe doesn't track name_sent per tool
280                        &mut self.streamed_args_for_tool,
281                        &self.prev_tool_call_arr,
282                    );
283                    return Ok(StreamingParseResult::default());
284                }
285
286                calls.push(ToolCallItem {
287                    tool_index: tool_id,
288                    name: Some(tool_call.function.name.clone()),
289                    parameters: tool_call.function.arguments.clone(),
290                });
291
292                // Store in tracking arrays
293                if self.prev_tool_call_arr.len() <= tool_id {
294                    self.prev_tool_call_arr
295                        .resize_with(tool_id + 1, || Value::Null);
296                }
297
298                // Parse parameters as JSON and store
299                if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
300                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
301                        "name": tool_call.function.name,
302                        "arguments": args,
303                    });
304                }
305
306                if self.streamed_args_for_tool.len() <= tool_id {
307                    self.streamed_args_for_tool
308                        .resize_with(tool_id + 1, String::new);
309                }
310                self.streamed_args_for_tool[tool_id].clone_from(&tool_call.function.arguments);
311
312                self.current_tool_id += 1;
313            }
314
315            // Remove processed portion from buffer
316            self.buffer = current_text[block_end..].to_string();
317            return Ok(StreamingParseResult { normal_text, calls });
318        }
319
320        // No complete tool call yet - return normal text before start token
321        // Safe: start.is_none() case was handled above (early return)
322        let Some(start_pos) = start else {
323            return Ok(StreamingParseResult::default());
324        };
325        let normal_text = current_text[..start_pos].to_string();
326        self.buffer = current_text[start_pos..].to_string();
327
328        Ok(StreamingParseResult {
329            normal_text,
330            calls: vec![],
331        })
332    }
333
334    fn has_tool_markers(&self, text: &str) -> bool {
335        text.contains(self.bot_token)
336    }
337
338    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
339        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
340    }
341
342    fn reset(&mut self) {
343        self.buffer.clear();
344        self.prev_tool_call_arr.clear();
345        self.current_tool_id = -1;
346        self.streamed_args_for_tool.clear();
347    }
348}