tool_parser/parsers/
kimik2.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::ParserResult,
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// Kimi K2 format parser for tool calls
14///
15/// Handles the Kimi K2 specific format:
16/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
17///
18/// Features:
19/// - Token-based delimiters
20/// - Function calls with explicit indexing
21/// - JSON arguments
22///
23/// Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
24pub struct KimiK2Parser {
25    /// Regex for extracting complete tool calls
26    tool_call_extractor: Regex,
27    /// Regex for extracting partial tool calls (streaming)
28    stream_tool_call_extractor: Regex,
29    /// Regex pattern for removing completed tool calls from buffer
30    tool_call_end_pattern: Regex,
31    /// Robust parser for ids like "functions.search:0" or fallback "search:0"
32    tool_call_id_regex: Regex,
33
34    /// Buffer for accumulating incomplete patterns across chunks
35    buffer: String,
36
37    /// Stores complete tool call info (name and arguments) for each tool being parsed
38    prev_tool_call_arr: Vec<Value>,
39
40    /// Index of currently streaming tool call (-1 means no active tool)
41    current_tool_id: i32,
42
43    /// Flag for whether current tool's name has been sent to client
44    current_tool_name_sent: bool,
45
46    /// Tracks raw JSON string content streamed to client for each tool's arguments
47    streamed_args_for_tool: Vec<String>,
48
49    /// Tracks the last arguments sent for incremental diffing
50    last_arguments: String,
51}
52
53impl KimiK2Parser {
54    /// Create a new Kimi K2 parser
55    pub fn new() -> Self {
56        // Pattern for complete tool calls
57        let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
58        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
59
60        // Pattern for streaming (partial) tool calls
61        let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
62        let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
63
64        // Pattern for removing completed tool calls
65        let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
66        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
67
68        // Robust parser for ids like "functions.search:0" or fallback "search:0"
69        let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
70        let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
71
72        Self {
73            tool_call_extractor,
74            stream_tool_call_extractor,
75            tool_call_end_pattern,
76            tool_call_id_regex,
77            buffer: String::new(),
78            prev_tool_call_arr: Vec::new(),
79            current_tool_id: -1,
80            current_tool_name_sent: false,
81            streamed_args_for_tool: Vec::new(),
82            last_arguments: String::new(),
83        }
84    }
85
86    /// Parse function ID to extract name and index
87    fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
88        if let Some(captures) = self.tool_call_id_regex.captures(id) {
89            let name = captures.name("name")?.as_str().to_string();
90            let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
91            Some((name, index))
92        } else {
93            None
94        }
95    }
96}
97
98impl Default for KimiK2Parser {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104#[async_trait]
105impl ToolParser for KimiK2Parser {
106    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
107        if !self.has_tool_markers(text) {
108            return Ok((text.to_string(), vec![]));
109        }
110
111        // Find where tool calls begin
112        let idx = text.find("<|tool_calls_section_begin|>").unwrap();
113        let normal_text = text[..idx].to_string();
114
115        // Try to extract tool calls
116        let mut tools = Vec::new();
117        for captures in self.tool_call_extractor.captures_iter(text) {
118            if let (Some(id_match), Some(args_match)) = (
119                captures.name("tool_call_id"),
120                captures.name("function_arguments"),
121            ) {
122                let function_id = id_match.as_str();
123                let function_args = args_match.as_str();
124
125                // Parse function ID
126                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
127                    // Try to parse JSON arguments
128                    match serde_json::from_str::<Value>(function_args) {
129                        Ok(_) => {
130                            tools.push(ToolCall {
131                                function: FunctionCall {
132                                    name: func_name,
133                                    arguments: function_args.to_string(),
134                                },
135                            });
136                        }
137                        Err(e) => {
138                            tracing::debug!(
139                                "Failed to parse JSON arguments for {}: {}",
140                                func_name,
141                                e
142                            );
143                            continue;
144                        }
145                    }
146                } else {
147                    tracing::debug!("Failed to parse function ID: {}", function_id);
148                    continue;
149                }
150            }
151        }
152
153        // If no tools were successfully parsed despite having markers, return entire text as fallback
154        if tools.is_empty() {
155            return Ok((text.to_string(), vec![]));
156        }
157
158        Ok((normal_text, tools))
159    }
160
161    async fn parse_incremental(
162        &mut self,
163        chunk: &str,
164        tools: &[Tool],
165    ) -> ParserResult<StreamingParseResult> {
166        self.buffer.push_str(chunk);
167        let current_text = &self.buffer.clone();
168
169        // Check if we have a tool call (either the start token or individual tool call)
170        let has_tool_call =
171            self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
172
173        if !has_tool_call {
174            // No tool markers detected - return all buffered content as normal text
175            let mut normal_text = std::mem::take(&mut self.buffer);
176            // Remove end tokens if present
177            for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
178                normal_text = normal_text.replace(e_token, "");
179            }
180            return Ok(StreamingParseResult {
181                normal_text,
182                calls: vec![],
183            });
184        }
185
186        // Build tool indices for validation
187        let tool_indices = helpers::get_tool_indices(tools);
188
189        let mut calls: Vec<ToolCallItem> = Vec::new();
190
191        // Try to match streaming pattern
192        if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
193            if let (Some(id_match), Some(args_match)) = (
194                captures.name("tool_call_id"),
195                captures.name("function_arguments"),
196            ) {
197                let function_id = id_match.as_str();
198                let function_args = args_match.as_str();
199
200                // Parse function ID
201                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
202                    // Validate tool name
203                    if !tool_indices.contains_key(&func_name) {
204                        // Invalid tool name - skip this tool, preserve indexing for next tool
205                        tracing::debug!("Invalid tool name '{}' - skipping", func_name);
206                        helpers::reset_current_tool_state(
207                            &mut self.buffer,
208                            &mut self.current_tool_name_sent,
209                            &mut self.streamed_args_for_tool,
210                            &self.prev_tool_call_arr,
211                        );
212                        return Ok(StreamingParseResult::default());
213                    }
214
215                    // Initialize state if this is the first tool call
216                    if self.current_tool_id == -1 {
217                        self.current_tool_id = 0;
218                        self.prev_tool_call_arr = Vec::new();
219                        self.streamed_args_for_tool = vec![String::new()];
220                    }
221
222                    // Ensure we have enough entries in our tracking arrays
223                    helpers::ensure_capacity(
224                        self.current_tool_id,
225                        &mut self.prev_tool_call_arr,
226                        &mut self.streamed_args_for_tool,
227                    );
228
229                    // Send tool name if not sent yet
230                    if !self.current_tool_name_sent {
231                        calls.push(ToolCallItem {
232                            tool_index: self.current_tool_id as usize,
233                            name: Some(func_name.clone()),
234                            parameters: String::new(),
235                        });
236                        self.current_tool_name_sent = true;
237
238                        // Store the tool call info for serving layer completions endpoint
239                        let tool_id = self.current_tool_id as usize;
240                        if self.prev_tool_call_arr.len() <= tool_id {
241                            self.prev_tool_call_arr
242                                .resize_with(tool_id + 1, || Value::Null);
243                        }
244                        self.prev_tool_call_arr[tool_id] = serde_json::json!({
245                            "name": func_name,
246                            "arguments": {},
247                        });
248                    } else {
249                        // Compute incremental diff
250                        let argument_diff = if function_args.starts_with(&self.last_arguments) {
251                            &function_args[self.last_arguments.len()..]
252                        } else {
253                            function_args
254                        };
255
256                        // Split by end token before sending (like Python does)
257                        let parsed_args_diff =
258                            if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
259                                &argument_diff[..pos]
260                            } else {
261                                argument_diff
262                            };
263
264                        if !parsed_args_diff.is_empty() {
265                            calls.push(ToolCallItem {
266                                tool_index: self.current_tool_id as usize,
267                                name: None,
268                                parameters: parsed_args_diff.to_string(),
269                            });
270                            // Note: Python adds full diff to _last_arguments, not just parsed part
271                            self.last_arguments.push_str(argument_diff);
272                            let tool_id = self.current_tool_id as usize;
273                            if tool_id < self.streamed_args_for_tool.len() {
274                                self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
275                            }
276                        }
277
278                        // Check completeness - split by end token first
279                        let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
280                        {
281                            &function_args[..pos]
282                        } else {
283                            function_args
284                        };
285
286                        if helpers::is_complete_json(parsed_args) {
287                            // Update the stored arguments
288                            if let Ok(parsed_args_value) =
289                                serde_json::from_str::<Value>(parsed_args)
290                            {
291                                let tool_id = self.current_tool_id as usize;
292                                if tool_id < self.prev_tool_call_arr.len() {
293                                    if let Some(obj) =
294                                        self.prev_tool_call_arr[tool_id].as_object_mut()
295                                    {
296                                        obj.insert("arguments".to_string(), parsed_args_value);
297                                    }
298                                }
299                            }
300
301                            // Find the end of the current tool call and remove only that part from buffer
302                            if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
303                                // Remove the completed tool call from buffer, keep any remaining content
304                                self.buffer = current_text[mat.end()..].to_string();
305                            } else {
306                                self.buffer.clear();
307                            }
308
309                            let result = StreamingParseResult {
310                                normal_text: String::new(),
311                                calls,
312                            };
313
314                            self.current_tool_id += 1;
315                            self.last_arguments.clear();
316                            self.current_tool_name_sent = false;
317                            return Ok(result);
318                        }
319                    }
320                }
321            }
322        }
323
324        Ok(StreamingParseResult {
325            normal_text: String::new(),
326            calls,
327        })
328    }
329
330    fn has_tool_markers(&self, text: &str) -> bool {
331        text.contains("<|tool_calls_section_begin|>")
332    }
333
334    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
335        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
336    }
337
338    fn reset(&mut self) {
339        self.buffer.clear();
340        self.prev_tool_call_arr.clear();
341        self.current_tool_id = -1;
342        self.current_tool_name_sent = false;
343        self.streamed_args_for_tool.clear();
344        self.last_arguments.clear();
345    }
346}