Skip to main content

tool_parser/parsers/
deepseek.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use regex::Regex;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    parsers::helpers,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
11};
12
13/// DeepSeek V3 format parser for tool calls
14///
15/// Handles the DeepSeek V3 specific format that uses Unicode tokens:
16/// `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|><|tool▁calls▁end|>`
17///
18/// Features:
19/// - Unicode token delimiters
20/// - JSON arguments in code blocks
21/// - Support for multiple sequential tool calls
22///
23/// Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
24pub struct DeepSeekParser {
25    /// Regex for extracting complete tool calls
26    tool_call_extractor: Regex,
27    /// Regex for extracting function details
28    func_detail_extractor: Regex,
29    /// Regex for matching partial tool calls during streaming
30    partial_tool_call_regex: Regex,
31    /// Regex pattern for removing completed tool calls from buffer
32    tool_call_end_pattern: Regex,
33
34    /// Buffer for accumulating incomplete patterns across chunks
35    buffer: String,
36
37    /// Stores complete tool call info (name and arguments) for each tool being parsed
38    prev_tool_call_arr: Vec<Value>,
39
40    /// Index of currently streaming tool call (-1 means no active tool)
41    current_tool_id: i32,
42
43    /// Flag for whether current tool's name has been sent to client
44    current_tool_name_sent: bool,
45
46    /// Tracks raw JSON string content streamed to client for each tool's arguments
47    streamed_args_for_tool: Vec<String>,
48}
49
50impl DeepSeekParser {
51    /// Create a new DeepSeek parser
52    pub fn new() -> Self {
53        // Use (?s) flag for DOTALL mode to handle newlines
54        let tool_call_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
55        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
56
57        let func_detail_pattern = r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)\n```json\n(.*?)\n```<|tool▁call▁end|>";
58        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
59
60        // Partial pattern for streaming - uses .* (greedy) not .*? to match all partial content
61        let partial_pattern = r"(?s)<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)";
62        let partial_tool_call_regex = Regex::new(partial_pattern).expect("Valid regex pattern");
63
64        // Pattern for removing completed tool calls
65        let end_pattern = r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>";
66        let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern");
67
68        Self {
69            tool_call_extractor,
70            func_detail_extractor,
71            partial_tool_call_regex,
72            tool_call_end_pattern,
73            buffer: String::new(),
74            prev_tool_call_arr: Vec::new(),
75            current_tool_id: -1,
76            current_tool_name_sent: false,
77            streamed_args_for_tool: Vec::new(),
78        }
79    }
80
81    /// Parse a single tool call block - throws error if parsing fails
82    fn parse_tool_call(&self, block: &str) -> ParserResult<ToolCall> {
83        let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
84            ParserError::ParsingFailed("Failed to match tool call pattern".to_string())
85        })?;
86
87        // Get function type (should be "function")
88        let func_type = captures.get(1).map_or("", |m| m.as_str());
89        if func_type != "function" {
90            return Err(ParserError::ParsingFailed(format!(
91                "Invalid function type: {}",
92                func_type
93            )));
94        }
95
96        // Get function name
97        let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
98        if func_name.is_empty() {
99            return Err(ParserError::ParsingFailed(
100                "Empty function name".to_string(),
101            ));
102        }
103
104        // Get JSON arguments
105        let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
106
107        // Parse JSON arguments
108        let value = serde_json::from_str::<Value>(json_args)
109            .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
110
111        // Create arguments object
112        let args = if value.is_object() {
113            value
114        } else {
115            // If not an object, wrap it
116            serde_json::json!({ "value": value })
117        };
118
119        let arguments =
120            serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
121
122        Ok(ToolCall {
123            function: FunctionCall {
124                name: func_name.to_string(),
125                arguments,
126            },
127        })
128    }
129}
130
131impl Default for DeepSeekParser {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137#[async_trait]
138impl ToolParser for DeepSeekParser {
139    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
140        if !self.has_tool_markers(text) {
141            return Ok((text.to_string(), vec![]));
142        }
143
144        // Find where tool calls begin
145        let idx = text.find("<|tool▁calls▁begin|>").unwrap();
146        let normal_text = text[..idx].to_string();
147
148        // Try to extract tool calls, log warnings for failures
149        let mut tools = Vec::new();
150        for mat in self.tool_call_extractor.find_iter(text) {
151            match self.parse_tool_call(mat.as_str()) {
152                Ok(tool) => tools.push(tool),
153                Err(e) => {
154                    tracing::debug!("Failed to parse tool call: {}", e);
155                    continue;
156                }
157            }
158        }
159
160        // If no tools were successfully parsed despite having markers, return entire text as fallback
161        if tools.is_empty() {
162            return Ok((text.to_string(), vec![]));
163        }
164
165        Ok((normal_text, tools))
166    }
167
168    async fn parse_incremental(
169        &mut self,
170        chunk: &str,
171        tools: &[Tool],
172    ) -> ParserResult<StreamingParseResult> {
173        self.buffer.push_str(chunk);
174        let current_text = &self.buffer.clone();
175
176        // Check if we have a tool call (either the start token or individual tool call)
177        let has_tool_call =
178            self.has_tool_markers(current_text) || current_text.contains("<|tool▁call▁begin|>");
179
180        if !has_tool_call {
181            // No tool markers detected - return all buffered content as normal text
182            // Strip out end tokens if present
183            let mut normal_text = std::mem::take(&mut self.buffer);
184            for e_token in ["<|tool▁calls▁end|>", "```", "<|tool▁call▁end|>"] {
185                normal_text = normal_text.replace(e_token, "");
186            }
187            return Ok(StreamingParseResult {
188                normal_text,
189                calls: vec![],
190            });
191        }
192
193        // Build tool indices for validation
194        let tool_indices = helpers::get_tool_indices(tools);
195
196        let mut calls: Vec<ToolCallItem> = Vec::new();
197
198        // Try to match the partial tool call pattern
199        if let Some(captures) = self.partial_tool_call_regex.captures(current_text) {
200            let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
201            let func_args_raw = captures.get(3).map_or("", |m| m.as_str()).trim();
202
203            // Validate tool name
204            if !tool_indices.contains_key(func_name) {
205                // Invalid tool name - skip this tool, preserve indexing for next tool
206                tracing::debug!("Invalid tool name '{}' - skipping", func_name);
207                helpers::reset_current_tool_state(
208                    &mut self.buffer,
209                    &mut self.current_tool_name_sent,
210                    &mut self.streamed_args_for_tool,
211                    &self.prev_tool_call_arr,
212                );
213                return Ok(StreamingParseResult::default());
214            }
215
216            // Initialize state if this is the first tool call
217            if self.current_tool_id == -1 {
218                self.current_tool_id = 0;
219                self.prev_tool_call_arr = Vec::new();
220                self.streamed_args_for_tool = vec![String::new()];
221            }
222
223            // Ensure we have enough entries in our tracking arrays
224            helpers::ensure_capacity(
225                self.current_tool_id,
226                &mut self.prev_tool_call_arr,
227                &mut self.streamed_args_for_tool,
228            );
229
230            // Send tool name if not sent yet
231            if !self.current_tool_name_sent {
232                calls.push(ToolCallItem {
233                    tool_index: self.current_tool_id as usize,
234                    name: Some(func_name.to_string()),
235                    parameters: String::new(),
236                });
237                self.current_tool_name_sent = true;
238
239                // Store the tool call info for serving layer completions endpoint
240                let tool_id = self.current_tool_id as usize;
241                if self.prev_tool_call_arr.len() <= tool_id {
242                    self.prev_tool_call_arr
243                        .resize_with(tool_id + 1, || Value::Null);
244                }
245                self.prev_tool_call_arr[tool_id] = serde_json::json!({
246                    "name": func_name,
247                    "arguments": {},
248                });
249            } else {
250                // Compute incremental diff
251                let tool_id = self.current_tool_id as usize;
252                let last_sent = self
253                    .streamed_args_for_tool
254                    .get(tool_id)
255                    .map(|s| s.as_str())
256                    .unwrap_or("");
257
258                let argument_diff = func_args_raw
259                    .strip_prefix(last_sent)
260                    .unwrap_or(func_args_raw);
261
262                if !argument_diff.is_empty() {
263                    calls.push(ToolCallItem {
264                        tool_index: tool_id,
265                        name: None,
266                        parameters: argument_diff.to_string(),
267                    });
268                    if tool_id < self.streamed_args_for_tool.len() {
269                        self.streamed_args_for_tool[tool_id].push_str(argument_diff);
270                    }
271                }
272
273                // Check if JSON is complete
274                if helpers::is_complete_json(func_args_raw) {
275                    // Update the stored arguments
276                    if let Ok(parsed_args) = serde_json::from_str::<Value>(func_args_raw) {
277                        let tool_id = self.current_tool_id as usize;
278                        if tool_id < self.prev_tool_call_arr.len() {
279                            if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
280                                obj.insert("arguments".to_string(), parsed_args);
281                            }
282                        }
283                    }
284
285                    // Find the end of the current tool call and remove only that part from buffer
286                    if let Some(mat) = self.tool_call_end_pattern.find(current_text) {
287                        // Remove the completed tool call from buffer, keep any remaining content
288                        self.buffer = current_text[mat.end()..].to_string();
289                    } else {
290                        self.buffer.clear();
291                    }
292
293                    let result = StreamingParseResult {
294                        normal_text: String::new(),
295                        calls,
296                    };
297
298                    self.current_tool_id += 1;
299                    self.current_tool_name_sent = false;
300                    return Ok(result);
301                }
302            }
303        }
304
305        Ok(StreamingParseResult {
306            normal_text: String::new(),
307            calls,
308        })
309    }
310
311    fn has_tool_markers(&self, text: &str) -> bool {
312        text.contains("<|tool▁calls▁begin|>")
313    }
314
315    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
316        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
317    }
318
319    fn reset(&mut self) {
320        self.buffer.clear();
321        self.prev_tool_call_arr.clear();
322        self.current_tool_id = -1;
323        self.current_tool_name_sent = false;
324        self.streamed_args_for_tool.clear();
325    }
326}