Skip to main content

tool_parser/parsers/
kimik2.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// Kimi K2 format parser for tool calls
14///
15/// Handles the Kimi K2 specific format:
16/// `<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|><|tool_calls_section_end|>`
17///
18/// Features:
19/// - Token-based delimiters
20/// - Function calls with explicit indexing
21/// - JSON arguments
22///
23/// Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
24pub struct KimiK2Parser {
25    /// Regex for extracting complete tool calls
26    tool_call_extractor: Regex,
27    /// Regex for extracting partial tool calls (streaming)
28    stream_tool_call_extractor: Regex,
29    /// Regex pattern for removing completed tool calls from buffer
30    tool_call_end_pattern: Regex,
31    /// Robust parser for ids like "functions.search:0" or fallback "search:0"
32    tool_call_id_regex: Regex,
33
34    /// Buffer for accumulating incomplete patterns across chunks
35    buffer: String,
36
37    /// Stores complete tool call info (name and arguments) for each tool being parsed
38    prev_tool_call_arr: Vec<Value>,
39
40    /// Index of currently streaming tool call (-1 means no active tool)
41    current_tool_id: i32,
42
43    /// Flag for whether current tool's name has been sent to client
44    current_tool_name_sent: bool,
45
46    /// Tracks raw JSON string content streamed to client for each tool's arguments
47    streamed_args_for_tool: Vec<String>,
48
49    /// Tracks the last arguments sent for incremental diffing
50    last_arguments: String,
51}
52
53impl KimiK2Parser {
54    /// Create a new Kimi K2 parser
55    #[expect(
56        clippy::expect_used,
57        reason = "regex patterns are compile-time string literals"
58    )]
59    pub fn new() -> Self {
60        // Pattern for complete tool calls
61        let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
62        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
63
64        // Pattern for streaming (partial) tool calls
65        let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
66        let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
67
68        // Pattern for removing completed tool calls
69        let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
70        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
71
72        // Robust parser for ids like "functions.search:0" or fallback "search:0"
73        let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
74        let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
75
76        Self {
77            tool_call_extractor,
78            stream_tool_call_extractor,
79            tool_call_end_pattern,
80            tool_call_id_regex,
81            buffer: String::new(),
82            prev_tool_call_arr: Vec::new(),
83            current_tool_id: -1,
84            current_tool_name_sent: false,
85            streamed_args_for_tool: Vec::new(),
86            last_arguments: String::new(),
87        }
88    }
89
90    /// Parse function ID to extract name and index
91    fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
92        if let Some(captures) = self.tool_call_id_regex.captures(id) {
93            let name = captures.name("name")?.as_str().to_string();
94            let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
95            Some((name, index))
96        } else {
97            None
98        }
99    }
100}
101
102impl Default for KimiK2Parser {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108#[async_trait]
109impl ToolParser for KimiK2Parser {
110    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
111        if !self.has_tool_markers(text) {
112            return Ok((text.to_string(), vec![]));
113        }
114
115        // Find where tool calls begin
116        // Safe: has_tool_markers() already confirmed the marker exists
117        let idx = text
118            .find("<|tool_calls_section_begin|>")
119            .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
120        let normal_text = text[..idx].to_string();
121
122        // Try to extract tool calls
123        let mut tools = Vec::new();
124        for captures in self.tool_call_extractor.captures_iter(text) {
125            if let (Some(id_match), Some(args_match)) = (
126                captures.name("tool_call_id"),
127                captures.name("function_arguments"),
128            ) {
129                let function_id = id_match.as_str();
130                let function_args = args_match.as_str();
131
132                // Parse function ID
133                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
134                    // Try to parse JSON arguments
135                    match serde_json::from_str::<Value>(function_args) {
136                        Ok(_) => {
137                            tools.push(ToolCall {
138                                function: FunctionCall {
139                                    name: func_name,
140                                    arguments: function_args.to_string(),
141                                },
142                            });
143                        }
144                        Err(e) => {
145                            tracing::debug!(
146                                "Failed to parse JSON arguments for {}: {}",
147                                func_name,
148                                e
149                            );
150                            continue;
151                        }
152                    }
153                } else {
154                    tracing::debug!("Failed to parse function ID: {}", function_id);
155                    continue;
156                }
157            }
158        }
159
160        // If no tools were successfully parsed despite having markers, return entire text as fallback
161        if tools.is_empty() {
162            return Ok((text.to_string(), vec![]));
163        }
164
165        Ok((normal_text, tools))
166    }
167
168    async fn parse_incremental(
169        &mut self,
170        chunk: &str,
171        tools: &[Tool],
172    ) -> ParserResult<StreamingParseResult> {
173        self.buffer.push_str(chunk);
174        let current_text = &self.buffer.clone();
175
176        // Check if we have a tool call (either the start token or individual tool call)
177        let has_tool_call =
178            self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
179
180        if !has_tool_call {
181            // No tool markers detected - return all buffered content as normal text
182            let mut normal_text = std::mem::take(&mut self.buffer);
183            // Remove end tokens if present
184            for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
185                normal_text = normal_text.replace(e_token, "");
186            }
187            return Ok(StreamingParseResult {
188                normal_text,
189                calls: vec![],
190            });
191        }
192
193        // Build tool indices for validation
194        let tool_indices = helpers::get_tool_indices(tools);
195
196        let mut calls: Vec<ToolCallItem> = Vec::new();
197
198        // Try to match streaming pattern
199        if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
200            if let (Some(id_match), Some(args_match)) = (
201                captures.name("tool_call_id"),
202                captures.name("function_arguments"),
203            ) {
204                let function_id = id_match.as_str();
205                let function_args = args_match.as_str();
206
207                // Parse function ID
208                if let Some((func_name, _index)) = self.parse_function_id(function_id) {
209                    // Validate tool name
210                    if !tool_indices.contains_key(&func_name) {
211                        // Invalid tool name - skip this tool, preserve indexing for next tool
212                        tracing::debug!("Invalid tool name '{}' - skipping", func_name);
213                        helpers::reset_current_tool_state(
214                            &mut self.buffer,
215                            &mut self.current_tool_name_sent,
216                            &mut self.streamed_args_for_tool,
217                            &self.prev_tool_call_arr,
218                        );
219                        return Ok(StreamingParseResult::default());
220                    }
221
222                    // Initialize state if this is the first tool call
223                    if self.current_tool_id == -1 {
224                        self.current_tool_id = 0;
225                        self.prev_tool_call_arr = Vec::new();
226                        self.streamed_args_for_tool = vec![String::new()];
227                    }
228
229                    // Ensure we have enough entries in our tracking arrays
230                    helpers::ensure_capacity(
231                        self.current_tool_id,
232                        &mut self.prev_tool_call_arr,
233                        &mut self.streamed_args_for_tool,
234                    );
235
236                    // Send tool name if not sent yet
237                    if self.current_tool_name_sent {
238                        // Compute incremental diff
239                        let argument_diff = if function_args.starts_with(&self.last_arguments) {
240                            &function_args[self.last_arguments.len()..]
241                        } else {
242                            function_args
243                        };
244
245                        // Split by end token before sending (like Python does)
246                        let parsed_args_diff =
247                            if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
248                                &argument_diff[..pos]
249                            } else {
250                                argument_diff
251                            };
252
253                        if !parsed_args_diff.is_empty() {
254                            calls.push(ToolCallItem {
255                                tool_index: self.current_tool_id as usize,
256                                name: None,
257                                parameters: parsed_args_diff.to_string(),
258                            });
259                            // Note: Python adds full diff to _last_arguments, not just parsed part
260                            self.last_arguments.push_str(argument_diff);
261                            let tool_id = self.current_tool_id as usize;
262                            if tool_id < self.streamed_args_for_tool.len() {
263                                self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
264                            }
265                        }
266
267                        // Check completeness - split by end token first
268                        let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
269                        {
270                            &function_args[..pos]
271                        } else {
272                            function_args
273                        };
274
275                        if helpers::is_complete_json(parsed_args) {
276                            // Update the stored arguments
277                            if let Ok(parsed_args_value) =
278                                serde_json::from_str::<Value>(parsed_args)
279                            {
280                                let tool_id = self.current_tool_id as usize;
281                                if tool_id < self.prev_tool_call_arr.len() {
282                                    if let Some(obj) =
283                                        self.prev_tool_call_arr[tool_id].as_object_mut()
284                                    {
285                                        obj.insert("arguments".to_string(), parsed_args_value);
286                                    }
287                                }
288                            }
289
290                            // Find the end of the current tool call and remove only that part from buffer
291                            if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
292                                // Remove the completed tool call from buffer, keep any remaining content
293                                self.buffer = current_text[mat.end()..].to_string();
294                            } else {
295                                self.buffer.clear();
296                            }
297
298                            let result = StreamingParseResult {
299                                normal_text: String::new(),
300                                calls,
301                            };
302
303                            self.current_tool_id += 1;
304                            self.last_arguments.clear();
305                            self.current_tool_name_sent = false;
306                            return Ok(result);
307                        }
308                    } else {
309                        calls.push(ToolCallItem {
310                            tool_index: self.current_tool_id as usize,
311                            name: Some(func_name.clone()),
312                            parameters: String::new(),
313                        });
314                        self.current_tool_name_sent = true;
315
316                        // Store the tool call info for serving layer completions endpoint
317                        let tool_id = self.current_tool_id as usize;
318                        if self.prev_tool_call_arr.len() <= tool_id {
319                            self.prev_tool_call_arr
320                                .resize_with(tool_id + 1, || Value::Null);
321                        }
322                        self.prev_tool_call_arr[tool_id] = serde_json::json!({
323                            "name": func_name,
324                            "arguments": {},
325                        });
326                    }
327                }
328            }
329        }
330
331        Ok(StreamingParseResult {
332            normal_text: String::new(),
333            calls,
334        })
335    }
336
337    fn has_tool_markers(&self, text: &str) -> bool {
338        text.contains("<|tool_calls_section_begin|>")
339    }
340
341    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
342        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
343    }
344
345    fn reset(&mut self) {
346        self.buffer.clear();
347        self.prev_tool_call_arr.clear();
348        self.current_tool_id = -1;
349        self.current_tool_name_sent = false;
350        self.streamed_args_for_tool.clear();
351        self.last_arguments.clear();
352    }
353}