tool_parser/parsers/
glm4_moe.rs1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7 errors::{ParserError, ParserResult},
8 parsers::helpers,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13pub struct Glm4MoeParser {
24 tool_call_extractor: Regex,
26 func_detail_extractor: Regex,
28 arg_extractor: Regex,
30
31 buffer: String,
33
34 prev_tool_call_arr: Vec<Value>,
36
37 current_tool_id: i32,
39
40 streamed_args_for_tool: Vec<String>,
42
43 bot_token: &'static str,
45 eot_token: &'static str,
46}
47
48impl Glm4MoeParser {
49 #[expect(
56 clippy::expect_used,
57 reason = "regex patterns are compile-time string literals"
58 )]
59 pub(crate) fn new(func_detail_pattern: &str) -> Self {
60 let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
62 let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
63
64 let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
65
66 let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
67 let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
68
69 Self {
70 tool_call_extractor,
71 func_detail_extractor,
72 arg_extractor,
73 buffer: String::new(),
74 prev_tool_call_arr: Vec::new(),
75 current_tool_id: -1,
76 streamed_args_for_tool: Vec::new(),
77 bot_token: "<tool_call>",
78 eot_token: "</tool_call>",
79 }
80 }
81
82 pub fn glm45() -> Self {
84 Self::new(r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>")
85 }
86
87 pub fn glm47() -> Self {
89 Self::new(r"(?s)<tool_call>\s*([^<\s]+)\s*(.*?)</tool_call>")
90 }
91
92 fn parse_arguments(&self, args_text: &str) -> serde_json::Map<String, Value> {
94 let mut arguments = serde_json::Map::new();
95
96 for capture in self.arg_extractor.captures_iter(args_text) {
97 let key = capture.get(1).map_or("", |m| m.as_str()).trim();
98 let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
99
100 let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
102 json_val
103 } else {
104 if value_str == "true" || value_str == "True" {
106 Value::Bool(true)
107 } else if value_str == "false" || value_str == "False" {
108 Value::Bool(false)
109 } else if value_str == "null" || value_str == "None" {
110 Value::Null
111 } else if let Ok(num) = value_str.parse::<i64>() {
112 Value::Number(num.into())
113 } else if let Ok(num) = value_str.parse::<f64>() {
114 if let Some(n) = serde_json::Number::from_f64(num) {
115 Value::Number(n)
116 } else {
117 Value::String(value_str.to_string())
118 }
119 } else {
120 Value::String(value_str.to_string())
121 }
122 };
123
124 arguments.insert(key.to_string(), value);
125 }
126
127 arguments
128 }
129
130 fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
132 if let Some(captures) = self.func_detail_extractor.captures(block) {
133 let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
135
136 let args_text = captures.get(2).map_or("", |m| m.as_str());
138
139 let arguments = self.parse_arguments(args_text);
141
142 let arguments_str = serde_json::to_string(&arguments)
143 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
144
145 Ok(Some(ToolCall {
146 function: FunctionCall {
147 name: func_name.to_string(),
148 arguments: arguments_str,
149 },
150 }))
151 } else {
152 Ok(None)
153 }
154 }
155
156 fn parse_tool_calls_from_text(&self, text: &str) -> Vec<ToolCall> {
158 let mut tools = Vec::new();
159
160 for mat in self.tool_call_extractor.find_iter(text) {
161 match self.parse_tool_call(mat.as_str()) {
162 Ok(Some(tool)) => tools.push(tool),
163 Ok(None) => continue,
164 Err(e) => {
165 tracing::debug!("Failed to parse tool call: {}", e);
166 continue;
167 }
168 }
169 }
170
171 tools
172 }
173}
174
175impl Default for Glm4MoeParser {
176 fn default() -> Self {
177 Self::glm45()
178 }
179}
180
181#[async_trait]
182impl ToolParser for Glm4MoeParser {
183 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
184 if !self.has_tool_markers(text) {
186 return Ok((text.to_string(), vec![]));
187 }
188
189 let idx = text
192 .find("<tool_call>")
193 .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
194 let normal_text = text[..idx].to_string();
195
196 let tools = self.parse_tool_calls_from_text(text);
198
199 if tools.is_empty() {
201 return Ok((text.to_string(), vec![]));
202 }
203
204 Ok((normal_text, tools))
205 }
206
207 async fn parse_incremental(
208 &mut self,
209 chunk: &str,
210 tools: &[Tool],
211 ) -> ParserResult<StreamingParseResult> {
212 self.buffer.push_str(chunk);
214 let current_text = &self.buffer.clone();
215
216 let start = current_text.find(self.bot_token);
218 if start.is_none() {
219 self.buffer.clear();
220 let normal_text = if self.current_tool_id > 0 {
222 String::new()
223 } else {
224 current_text.clone()
225 };
226 return Ok(StreamingParseResult {
227 normal_text,
228 calls: vec![],
229 });
230 }
231
232 let end = current_text.find(self.eot_token);
234 if let Some(end_pos) = end {
235 if self.current_tool_id == -1 {
239 self.current_tool_id = 0;
240 self.prev_tool_call_arr = Vec::new();
241 self.streamed_args_for_tool = vec![String::new()];
242 }
243
244 helpers::ensure_capacity(
246 self.current_tool_id,
247 &mut self.prev_tool_call_arr,
248 &mut self.streamed_args_for_tool,
249 );
250
251 let block_end = end_pos + self.eot_token.len();
253 let parsed_tools = self.parse_tool_calls_from_text(¤t_text[..block_end]);
254
255 let idx = current_text.find(self.bot_token);
257 let normal_text = if let Some(pos) = idx {
258 current_text[..pos].trim().to_string()
259 } else {
260 String::new()
261 };
262
263 let tool_indices = helpers::get_tool_indices(tools);
265
266 let mut calls = Vec::new();
267
268 if !parsed_tools.is_empty() {
269 let tool_call = &parsed_tools[0];
271 let tool_id = self.current_tool_id as usize;
272
273 if !tool_indices.contains_key(&tool_call.function.name) {
275 tracing::debug!("Invalid tool name '{}' - skipping", tool_call.function.name);
277 helpers::reset_current_tool_state(
278 &mut self.buffer,
279 &mut false, &mut self.streamed_args_for_tool,
281 &self.prev_tool_call_arr,
282 );
283 return Ok(StreamingParseResult::default());
284 }
285
286 calls.push(ToolCallItem {
287 tool_index: tool_id,
288 name: Some(tool_call.function.name.clone()),
289 parameters: tool_call.function.arguments.clone(),
290 });
291
292 if self.prev_tool_call_arr.len() <= tool_id {
294 self.prev_tool_call_arr
295 .resize_with(tool_id + 1, || Value::Null);
296 }
297
298 if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
300 self.prev_tool_call_arr[tool_id] = serde_json::json!({
301 "name": tool_call.function.name,
302 "arguments": args,
303 });
304 }
305
306 if self.streamed_args_for_tool.len() <= tool_id {
307 self.streamed_args_for_tool
308 .resize_with(tool_id + 1, String::new);
309 }
310 self.streamed_args_for_tool[tool_id].clone_from(&tool_call.function.arguments);
311
312 self.current_tool_id += 1;
313 }
314
315 self.buffer = current_text[block_end..].to_string();
317 return Ok(StreamingParseResult { normal_text, calls });
318 }
319
320 let Some(start_pos) = start else {
323 return Ok(StreamingParseResult::default());
324 };
325 let normal_text = current_text[..start_pos].to_string();
326 self.buffer = current_text[start_pos..].to_string();
327
328 Ok(StreamingParseResult {
329 normal_text,
330 calls: vec![],
331 })
332 }
333
334 fn has_tool_markers(&self, text: &str) -> bool {
335 text.contains(self.bot_token)
336 }
337
338 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
339 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
340 }
341
342 fn reset(&mut self) {
343 self.buffer.clear();
344 self.prev_tool_call_arr.clear();
345 self.current_tool_id = -1;
346 self.streamed_args_for_tool.clear();
347 }
348}