tool_parser/parsers/
kimik2.rs1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7 errors::ParserResult,
8 parsers::helpers,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13pub struct KimiK2Parser {
25 tool_call_extractor: Regex,
27 stream_tool_call_extractor: Regex,
29 tool_call_end_pattern: Regex,
31 tool_call_id_regex: Regex,
33
34 buffer: String,
36
37 prev_tool_call_arr: Vec<Value>,
39
40 current_tool_id: i32,
42
43 current_tool_name_sent: bool,
45
46 streamed_args_for_tool: Vec<String>,
48
49 last_arguments: String,
51}
52
53impl KimiK2Parser {
54 pub fn new() -> Self {
56 let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>";
58 let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
59
60 let stream_pattern = r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)";
62 let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern");
63
64 let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>";
66 let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
67
68 let id_pattern = r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$";
70 let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern");
71
72 Self {
73 tool_call_extractor,
74 stream_tool_call_extractor,
75 tool_call_end_pattern,
76 tool_call_id_regex,
77 buffer: String::new(),
78 prev_tool_call_arr: Vec::new(),
79 current_tool_id: -1,
80 current_tool_name_sent: false,
81 streamed_args_for_tool: Vec::new(),
82 last_arguments: String::new(),
83 }
84 }
85
86 fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
88 if let Some(captures) = self.tool_call_id_regex.captures(id) {
89 let name = captures.name("name")?.as_str().to_string();
90 let index = captures.name("index")?.as_str().parse::<usize>().ok()?;
91 Some((name, index))
92 } else {
93 None
94 }
95 }
96}
97
98impl Default for KimiK2Parser {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104#[async_trait]
105impl ToolParser for KimiK2Parser {
106 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
107 if !self.has_tool_markers(text) {
108 return Ok((text.to_string(), vec![]));
109 }
110
111 let idx = text.find("<|tool_calls_section_begin|>").unwrap();
113 let normal_text = text[..idx].to_string();
114
115 let mut tools = Vec::new();
117 for captures in self.tool_call_extractor.captures_iter(text) {
118 if let (Some(id_match), Some(args_match)) = (
119 captures.name("tool_call_id"),
120 captures.name("function_arguments"),
121 ) {
122 let function_id = id_match.as_str();
123 let function_args = args_match.as_str();
124
125 if let Some((func_name, _index)) = self.parse_function_id(function_id) {
127 match serde_json::from_str::<Value>(function_args) {
129 Ok(_) => {
130 tools.push(ToolCall {
131 function: FunctionCall {
132 name: func_name,
133 arguments: function_args.to_string(),
134 },
135 });
136 }
137 Err(e) => {
138 tracing::debug!(
139 "Failed to parse JSON arguments for {}: {}",
140 func_name,
141 e
142 );
143 continue;
144 }
145 }
146 } else {
147 tracing::debug!("Failed to parse function ID: {}", function_id);
148 continue;
149 }
150 }
151 }
152
153 if tools.is_empty() {
155 return Ok((text.to_string(), vec![]));
156 }
157
158 Ok((normal_text, tools))
159 }
160
161 async fn parse_incremental(
162 &mut self,
163 chunk: &str,
164 tools: &[Tool],
165 ) -> ParserResult<StreamingParseResult> {
166 self.buffer.push_str(chunk);
167 let current_text = &self.buffer.clone();
168
169 let has_tool_call =
171 self.has_tool_markers(current_text) || current_text.contains("<|tool_call_begin|>");
172
173 if !has_tool_call {
174 let mut normal_text = std::mem::take(&mut self.buffer);
176 for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] {
178 normal_text = normal_text.replace(e_token, "");
179 }
180 return Ok(StreamingParseResult {
181 normal_text,
182 calls: vec![],
183 });
184 }
185
186 let tool_indices = helpers::get_tool_indices(tools);
188
189 let mut calls: Vec<ToolCallItem> = Vec::new();
190
191 if let Some(captures) = self.stream_tool_call_extractor.captures(current_text) {
193 if let (Some(id_match), Some(args_match)) = (
194 captures.name("tool_call_id"),
195 captures.name("function_arguments"),
196 ) {
197 let function_id = id_match.as_str();
198 let function_args = args_match.as_str();
199
200 if let Some((func_name, _index)) = self.parse_function_id(function_id) {
202 if !tool_indices.contains_key(&func_name) {
204 tracing::debug!("Invalid tool name '{}' - skipping", func_name);
206 helpers::reset_current_tool_state(
207 &mut self.buffer,
208 &mut self.current_tool_name_sent,
209 &mut self.streamed_args_for_tool,
210 &self.prev_tool_call_arr,
211 );
212 return Ok(StreamingParseResult::default());
213 }
214
215 if self.current_tool_id == -1 {
217 self.current_tool_id = 0;
218 self.prev_tool_call_arr = Vec::new();
219 self.streamed_args_for_tool = vec![String::new()];
220 }
221
222 helpers::ensure_capacity(
224 self.current_tool_id,
225 &mut self.prev_tool_call_arr,
226 &mut self.streamed_args_for_tool,
227 );
228
229 if !self.current_tool_name_sent {
231 calls.push(ToolCallItem {
232 tool_index: self.current_tool_id as usize,
233 name: Some(func_name.clone()),
234 parameters: String::new(),
235 });
236 self.current_tool_name_sent = true;
237
238 let tool_id = self.current_tool_id as usize;
240 if self.prev_tool_call_arr.len() <= tool_id {
241 self.prev_tool_call_arr
242 .resize_with(tool_id + 1, || Value::Null);
243 }
244 self.prev_tool_call_arr[tool_id] = serde_json::json!({
245 "name": func_name,
246 "arguments": {},
247 });
248 } else {
249 let argument_diff = if function_args.starts_with(&self.last_arguments) {
251 &function_args[self.last_arguments.len()..]
252 } else {
253 function_args
254 };
255
256 let parsed_args_diff =
258 if let Some(pos) = argument_diff.find("<|tool_call_end|>") {
259 &argument_diff[..pos]
260 } else {
261 argument_diff
262 };
263
264 if !parsed_args_diff.is_empty() {
265 calls.push(ToolCallItem {
266 tool_index: self.current_tool_id as usize,
267 name: None,
268 parameters: parsed_args_diff.to_string(),
269 });
270 self.last_arguments.push_str(argument_diff);
272 let tool_id = self.current_tool_id as usize;
273 if tool_id < self.streamed_args_for_tool.len() {
274 self.streamed_args_for_tool[tool_id].push_str(parsed_args_diff);
275 }
276 }
277
278 let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>")
280 {
281 &function_args[..pos]
282 } else {
283 function_args
284 };
285
286 if helpers::is_complete_json(parsed_args) {
287 if let Ok(parsed_args_value) =
289 serde_json::from_str::<Value>(parsed_args)
290 {
291 let tool_id = self.current_tool_id as usize;
292 if tool_id < self.prev_tool_call_arr.len() {
293 if let Some(obj) =
294 self.prev_tool_call_arr[tool_id].as_object_mut()
295 {
296 obj.insert("arguments".to_string(), parsed_args_value);
297 }
298 }
299 }
300
301 if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
303 self.buffer = current_text[mat.end()..].to_string();
305 } else {
306 self.buffer.clear();
307 }
308
309 let result = StreamingParseResult {
310 normal_text: String::new(),
311 calls,
312 };
313
314 self.current_tool_id += 1;
315 self.last_arguments.clear();
316 self.current_tool_name_sent = false;
317 return Ok(result);
318 }
319 }
320 }
321 }
322 }
323
324 Ok(StreamingParseResult {
325 normal_text: String::new(),
326 calls,
327 })
328 }
329
330 fn has_tool_markers(&self, text: &str) -> bool {
331 text.contains("<|tool_calls_section_begin|>")
332 }
333
334 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
335 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
336 }
337
338 fn reset(&mut self) {
339 self.buffer.clear();
340 self.prev_tool_call_arr.clear();
341 self.current_tool_id = -1;
342 self.current_tool_name_sent = false;
343 self.streamed_args_for_tool.clear();
344 self.last_arguments.clear();
345 }
346}