tool_parser/parsers/
deepseek.rs1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7 errors::{ParserError, ParserResult},
8 parsers::helpers,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13pub struct DeepSeekParser {
25 tool_call_extractor: Regex,
27 func_detail_extractor: Regex,
29 partial_tool_call_regex: Regex,
31 tool_call_end_pattern: Regex,
33
34 buffer: String,
36
37 prev_tool_call_arr: Vec<Value>,
39
40 current_tool_id: i32,
42
43 current_tool_name_sent: bool,
45
46 streamed_args_for_tool: Vec<String>,
48}
49
50impl DeepSeekParser {
51 pub fn new() -> Self {
53 let tool_call_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
55 let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
56
57 let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
58 let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
59
60 let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
62 let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
63
64 let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
66 let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
67
68 Self {
69 tool_call_extractor,
70 func_detail_extractor,
71 partial_tool_call_regex,
72 tool_call_end_pattern,
73 buffer: String::new(),
74 prev_tool_call_arr: Vec::new(),
75 current_tool_id: -1,
76 current_tool_name_sent: false,
77 streamed_args_for_tool: Vec::new(),
78 }
79 }
80
81 fn parse_tool_call(&self, block: &str) -> ParserResult<ToolCall> {
83 let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
84 ParserError::ParsingFailed("Failed to match tool call pattern".to_string())
85 })?;
86
87 let func_type = captures.get(1).map_or("", |m| m.as_str());
89 if func_type != "function" {
90 return Err(ParserError::ParsingFailed(format!(
91 "Invalid function type: {}",
92 func_type
93 )));
94 }
95
96 let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
98 if func_name.is_empty() {
99 return Err(ParserError::ParsingFailed(
100 "Empty function name".to_string(),
101 ));
102 }
103
104 let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
106
107 let value = serde_json::from_str::<Value>(json_args)
109 .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
110
111 let args = if value.is_object() {
113 value
114 } else {
115 serde_json::json!({ "value": value })
117 };
118
119 let arguments =
120 serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
121
122 Ok(ToolCall {
123 function: FunctionCall {
124 name: func_name.to_string(),
125 arguments,
126 },
127 })
128 }
129}
130
131impl Default for DeepSeekParser {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[async_trait]
138impl ToolParser for DeepSeekParser {
139 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
140 if !self.has_tool_markers(text) {
141 return Ok((text.to_string(), vec![]));
142 }
143
144 let idx = text.find("<|tool▁calls▁begin|>").unwrap();
146 let normal_text = text[..idx].to_string();
147
148 let mut tools = Vec::new();
150 for mat in self.tool_call_extractor.find_iter(text) {
151 match self.parse_tool_call(mat.as_str()) {
152 Ok(tool) => tools.push(tool),
153 Err(e) => {
154 tracing::debug!("Failed to parse tool call: {}", e);
155 continue;
156 }
157 }
158 }
159
160 if tools.is_empty() {
162 return Ok((text.to_string(), vec![]));
163 }
164
165 Ok((normal_text, tools))
166 }
167
168 async fn parse_incremental(
169 &mut self,
170 chunk: &str,
171 tools: &[Tool],
172 ) -> ParserResult<StreamingParseResult> {
173 self.buffer.push_str(chunk);
174 let current_text = &self.buffer.clone();
175
176 let has_tool_call =
178 self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
179
180 if !has_tool_call {
181 let mut normal_text = std::mem::take(&mut self.buffer);
184 for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
185 normal_text = normal_text.replace(e_token, "");
186 }
187 return Ok(StreamingParseResult {
188 normal_text,
189 calls: vec![],
190 });
191 }
192
193 let tool_indices = helpers::get_tool_indices(tools);
195
196 let mut calls: Vec<ToolCallItem> = Vec::new();
197
198 if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
200 let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
201 let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
202
203 if !tool_indices.contains_key(func_name) {
205 tracing::debug!("Invalid tool name '{}' - skipping", func_name);
207 helpers::reset_current_tool_state(
208 &mut self.buffer,
209 &mut self.current_tool_name_sent,
210 &mut self.streamed_args_for_tool,
211 &self.prev_tool_call_arr,
212 );
213 return Ok(StreamingParseResult::default());
214 }
215
216 if self.current_tool_id == -1 {
218 self.current_tool_id = 0;
219 self.prev_tool_call_arr = Vec::new();
220 self.streamed_args_for_tool = vec![String::new()];
221 }
222
223 helpers::ensure_capacity(
225 self.current_tool_id,
226 &mut self.prev_tool_call_arr,
227 &mut self.streamed_args_for_tool,
228 );
229
230 if !self.current_tool_name_sent {
232 calls.push(ToolCallItem {
233 tool_index: self.current_tool_id as usize,
234 name: Some(func_name.to_string()),
235 parameters: String::new(),
236 });
237 self.current_tool_name_sent = true;
238
239 let tool_id = self.current_tool_id as usize;
241 if self.prev_tool_call_arr.len() <= tool_id {
242 self.prev_tool_call_arr
243 .resize_with(tool_id + 1, || Value::Null);
244 }
245 self.prev_tool_call_arr[tool_id] = serde_json::json!({
246 "name": func_name,
247 "arguments": {},
248 });
249 } else {
250 let tool_id = self.current_tool_id as usize;
252 let last_sent = self
253 .streamed_args_for_tool
254 .get(tool_id)
255 .map(|s| s.as_str())
256 .unwrap_or("");
257
258 let argument_diff = func_args_raw
259 .strip_prefix(last_sent)
260 .unwrap_or(func_args_raw);
261
262 if !argument_diff.is_empty() {
263 calls.push(ToolCallItem {
264 tool_index: tool_id,
265 name: None,
266 parameters: argument_diff.to_string(),
267 });
268 if tool_id < self.streamed_args_for_tool.len() {
269 self.streamed_args_for_tool[tool_id].push_str(argument_diff);
270 }
271 }
272
273 if helpers::is_complete_json(func_args_raw) {
275 if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
277 let tool_id = self.current_tool_id as usize;
278 if tool_id < self.prev_tool_call_arr.len() {
279 if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
280 obj.insert("arguments".to_string(), parsed_args);
281 }
282 }
283 }
284
285 if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
287 self.buffer = current_text[mat.end()..].to_string();
289 } else {
290 self.buffer.clear();
291 }
292
293 let result = StreamingParseResult {
294 normal_text: String::new(),
295 calls,
296 };
297
298 self.current_tool_id += 1;
299 self.current_tool_name_sent = false;
300 return Ok(result);
301 }
302 }
303 }
304
305 Ok(StreamingParseResult {
306 normal_text: String::new(),
307 calls,
308 })
309 }
310
311 fn has_tool_markers(&self, text: &str) -> bool {
312 text.contains("<|tool▁calls▁begin|>")
313 }
314
315 fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
316 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
317 }
318
319 fn reset(&mut self) {
320 self.buffer.clear();
321 self.prev_tool_call_arr.clear();
322 self.current_tool_id = -1;
323 self.current_tool_name_sent = false;
324 self.streamed_args_for_tool.clear();
325 }
326}