// tool_parser/parsers/glm4_moe.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// GLM-4 MoE format parser for tool calls
14///
15/// Handles both GLM-4 MoE and GLM-4.7 MoE formats:
16/// - GLM-4: `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
17/// - GLM-4.7: `<tool_call>{name}<arg_key>{key}</arg_key><arg_value>{value}</arg_value></tool_call>`
18///
19/// Features:
20/// - XML-style tags for tool calls
21/// - Key-value pairs for arguments
22/// - Support for multiple sequential tool calls
23pub struct Glm4MoeParser {
24    /// Regex for extracting complete tool calls
25    tool_call_extractor: Regex,
26    /// Regex for extracting function details
27    func_detail_extractor: Regex,
28    /// Regex for extracting argument key-value pairs
29    arg_extractor: Regex,
30
31    /// Buffer for accumulating incomplete patterns across chunks
32    buffer: String,
33
34    /// Stores complete tool call info (name and arguments) for each tool being parsed
35    prev_tool_call_arr: Vec<Value>,
36
37    /// Index of currently streaming tool call (-1 means no active tool)
38    current_tool_id: i32,
39
40    /// Tracks raw JSON string content streamed to client for each tool's arguments
41    streamed_args_for_tool: Vec<String>,
42
43    /// Token configuration
44    bot_token: &'static str,
45    eot_token: &'static str,
46}
47
48impl Glm4MoeParser {
49    /// Create a new generic GLM MoE parser with a custom func_detail_extractor pattern
50    ///
51    /// # Arguments
52    /// - `func_detail_pattern`: Regex pattern for extracting function name and arguments
53    ///   - For GLM-4: `r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>"`
54    ///   - For GLM-4.7: `r"(?s)<tool_call>\s*([^<\s]+)\s*(.*?)</tool_call>"`
55    pub(crate) fn new(func_detail_pattern: &str) -> Self {
56        // Use (?s) flag for DOTALL mode to handle newlines
57        let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
58        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
59
60        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
61
62        let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
63        let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
64
65        Self {
66            tool_call_extractor,
67            func_detail_extractor,
68            arg_extractor,
69            buffer: String::new(),
70            prev_tool_call_arr: Vec::new(),
71            current_tool_id: -1,
72            streamed_args_for_tool: Vec::new(),
73            bot_token: "<tool_call>",
74            eot_token: "</tool_call>",
75        }
76    }
77
78    /// Create a new GLM-4.5/4.6 MoE parser (with newline-based format)
79    pub fn glm45() -> Self {
80        Self::new(r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>")
81    }
82
83    /// Create a new GLM-4.7 MoE parser (with whitespace-based format)
84    pub fn glm47() -> Self {
85        Self::new(r"(?s)<tool_call>\s*([^<\s]+)\s*(.*?)</tool_call>")
86    }
87
88    /// Parse arguments from key-value pairs
89    fn parse_arguments(&self, args_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
90        let mut arguments = serde_json::Map::new();
91
92        for capture in self.arg_extractor.captures_iter(args_text) {
93            let key = capture.get(1).map_or("", |m| m.as_str()).trim();
94            let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
95
96            // Try to parse the value as JSON first, fallback to string
97            let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
98                json_val
99            } else {
100                // Try parsing as Python literal (similar to Python's ast.literal_eval)
101                if value_str == "true" || value_str == "True" {
102                    Value::Bool(true)
103                } else if value_str == "false" || value_str == "False" {
104                    Value::Bool(false)
105                } else if value_str == "null" || value_str == "None" {
106                    Value::Null
107                } else if let Ok(num) = value_str.parse::<i64>() {
108                    Value::Number(num.into())
109                } else if let Ok(num) = value_str.parse::<f64>() {
110                    if let Some(n) = serde_json::Number::from_f64(num) {
111                        Value::Number(n)
112                    } else {
113                        Value::String(value_str.to_string())
114                    }
115                } else {
116                    Value::String(value_str.to_string())
117                }
118            };
119
120            arguments.insert(key.to_string(), value);
121        }
122
123        Ok(arguments)
124    }
125
126    /// Parse a single tool call block
127    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
128        if let Some(captures) = self.func_detail_extractor.captures(block) {
129            // Get function name
130            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
131
132            // Get arguments text
133            let args_text = captures.get(2).map_or("", |m| m.as_str());
134
135            // Parse arguments
136            let arguments = self.parse_arguments(args_text)?;
137
138            let arguments_str = serde_json::to_string(&arguments)
139                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
140
141            Ok(Some(ToolCall {
142                function: FunctionCall {
143                    name: func_name.to_string(),
144                    arguments: arguments_str,
145                },
146            }))
147        } else {
148            Ok(None)
149        }
150    }
151
152    /// Parse all tool calls from text (shared logic for complete and incremental parsing)
153    fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult<Vec<ToolCall>> {
154        let mut tools = Vec::new();
155
156        for mat in self.tool_call_extractor.find_iter(text) {
157            match self.parse_tool_call(mat.as_str()) {
158                Ok(Some(tool)) => tools.push(tool),
159                Ok(None) => continue,
160                Err(e) => {
161                    tracing::debug!("Failed to parse tool call: {}", e);
162                    continue;
163                }
164            }
165        }
166
167        Ok(tools)
168    }
169}
170
171impl Default for Glm4MoeParser {
172    fn default() -> Self {
173        Self::glm45()
174    }
175}
176
177#[async_trait]
178impl ToolParser for Glm4MoeParser {
179    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
180        // Check if text contains GLM-4 MoE format
181        if !self.has_tool_markers(text) {
182            return Ok((text.to_string(), vec![]));
183        }
184
185        // Find where tool calls begin
186        let idx = text.find("<tool_call>").unwrap();
187        let normal_text = text[..idx].to_string();
188
189        // Parse all tool calls using shared helper
190        let tools = self.parse_tool_calls_from_text(text)?;
191
192        // If no tools were successfully parsed despite having markers, return entire text as fallback
193        if tools.is_empty() {
194            return Ok((text.to_string(), vec![]));
195        }
196
197        Ok((normal_text, tools))
198    }
199
200    async fn parse_incremental(
201        &mut self,
202        chunk: &str,
203        tools: &[Tool],
204    ) -> ParserResult<StreamingParseResult> {
205        // Python logic: Wait for complete tool call, then parse it all at once
206        self.buffer.push_str(chunk);
207        let current_text = &self.buffer.clone();
208
209        // Check if we have bot_token
210        let start = current_text.find(self.bot_token);
211        if start.is_none() {
212            self.buffer.clear();
213            // If we're in the middle of streaming (current_tool_id > 0), don't return text
214            let normal_text = if self.current_tool_id > 0 {
215                String::new()
216            } else {
217                current_text.clone()
218            };
219            return Ok(StreamingParseResult {
220                normal_text,
221                calls: vec![],
222            });
223        }
224
225        // Check if we have eot_token (end of tool call)
226        let end = current_text.find(self.eot_token);
227        if let Some(end_pos) = end {
228            // We have a complete tool call!
229
230            // Initialize state if this is the first tool call
231            if self.current_tool_id == -1 {
232                self.current_tool_id = 0;
233                self.prev_tool_call_arr = Vec::new();
234                self.streamed_args_for_tool = vec![String::new()];
235            }
236
237            // Ensure we have enough entries in our tracking arrays
238            helpers::ensure_capacity(
239                self.current_tool_id,
240                &mut self.prev_tool_call_arr,
241                &mut self.streamed_args_for_tool,
242            );
243
244            // Parse the complete block using shared helper
245            let block_end = end_pos + self.eot_token.len();
246            let parsed_tools = self.parse_tool_calls_from_text(&current_text[..block_end])?;
247
248            // Extract normal text before tool calls
249            let idx = current_text.find(self.bot_token);
250            let normal_text = if let Some(pos) = idx {
251                current_text[..pos].trim().to_string()
252            } else {
253                String::new()
254            };
255
256            // Build tool indices for validation
257            let tool_indices = helpers::get_tool_indices(tools);
258
259            let mut calls = Vec::new();
260
261            if !parsed_tools.is_empty() {
262                // Take the first tool and convert to ToolCallItem
263                let tool_call = &parsed_tools[0];
264                let tool_id = self.current_tool_id as usize;
265
266                // Validate tool name
267                if !tool_indices.contains_key(&tool_call.function.name) {
268                    // Invalid tool name - skip this tool, preserve indexing for next tool
269                    tracing::debug!("Invalid tool name '{}' - skipping", tool_call.function.name);
270                    helpers::reset_current_tool_state(
271                        &mut self.buffer,
272                        &mut false, // glm45_moe/glm47_moe doesn't track name_sent per tool
273                        &mut self.streamed_args_for_tool,
274                        &self.prev_tool_call_arr,
275                    );
276                    return Ok(StreamingParseResult::default());
277                }
278
279                calls.push(ToolCallItem {
280                    tool_index: tool_id,
281                    name: Some(tool_call.function.name.clone()),
282                    parameters: tool_call.function.arguments.clone(),
283                });
284
285                // Store in tracking arrays
286                if self.prev_tool_call_arr.len() <= tool_id {
287                    self.prev_tool_call_arr
288                        .resize_with(tool_id + 1, || Value::Null);
289                }
290
291                // Parse parameters as JSON and store
292                if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
293                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
294                        "name": tool_call.function.name,
295                        "arguments": args,
296                    });
297                }
298
299                if self.streamed_args_for_tool.len() <= tool_id {
300                    self.streamed_args_for_tool
301                        .resize_with(tool_id + 1, String::new);
302                }
303                self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();
304
305                self.current_tool_id += 1;
306            }
307
308            // Remove processed portion from buffer
309            self.buffer = current_text[block_end..].to_string();
310            return Ok(StreamingParseResult { normal_text, calls });
311        }
312
313        // No complete tool call yet - return normal text before start token
314        let start_pos = start.unwrap();
315        let normal_text = current_text[..start_pos].to_string();
316        self.buffer = current_text[start_pos..].to_string();
317
318        Ok(StreamingParseResult {
319            normal_text,
320            calls: vec![],
321        })
322    }
323
324    fn has_tool_markers(&self, text: &str) -> bool {
325        text.contains(self.bot_token)
326    }
327
328    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
329        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
330    }
331
332    fn reset(&mut self) {
333        self.buffer.clear();
334        self.prev_tool_call_arr.clear();
335        self.current_tool_id = -1;
336        self.streamed_args_for_tool.clear();
337    }
338}