Skip to main content

tool_parser/parsers/
mistral.rs

1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use serde_json::Value;
4
5use crate::{
6    errors::{ParserError, ParserResult},
7    parsers::helpers,
8    partial_json::PartialJson,
9    traits::ToolParser,
10    types::{FunctionCall, StreamingParseResult, ToolCall},
11};
12
13/// Mistral format parser for tool calls
14///
15/// Handles the Mistral-specific format:
16/// `[TOOL_CALLS] [{"name": "func", "arguments": {...}}, ...]`
17///
18/// Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
19pub struct MistralParser {
20    /// Parser for handling incomplete JSON during streaming
21    partial_json: PartialJson,
22
23    /// Buffer for accumulating incomplete patterns across chunks
24    buffer: String,
25
26    /// Stores complete tool call info (name and arguments) for each tool being parsed
27    prev_tool_call_arr: Vec<Value>,
28
29    /// Index of currently streaming tool call (-1 means no active tool)
30    current_tool_id: i32,
31
32    /// Flag for whether current tool's name has been sent to client
33    current_tool_name_sent: bool,
34
35    /// Tracks raw JSON string content streamed to client for each tool's arguments
36    streamed_args_for_tool: Vec<String>,
37
38    /// Token configuration
39    bot_token: &'static str,
40    eot_token: &'static str,
41    tool_call_separator: &'static str,
42
43    /// Track whether we've already stripped the closing ] bracket
44    array_closed: bool,
45}
46
47impl MistralParser {
48    /// Create a new Mistral parser
49    pub fn new() -> Self {
50        Self {
51            partial_json: PartialJson::default(),
52            buffer: String::new(),
53            prev_tool_call_arr: Vec::new(),
54            current_tool_id: -1,
55            current_tool_name_sent: false,
56            streamed_args_for_tool: Vec::new(),
57            bot_token: "[TOOL_CALLS] [",
58            eot_token: "]",
59            tool_call_separator: ", ",
60            array_closed: false,
61        }
62    }
63
64    fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
65        const BOT_TOKEN: &str = "[TOOL_CALLS] [";
66
67        // Find the start of the token
68        let start_idx = text.find(BOT_TOKEN)?;
69
70        // Start from the opening bracket after [TOOL_CALLS]
71        // The -1 is to include the opening bracket that's part of the token
72        let json_start = start_idx + BOT_TOKEN.len() - 1;
73
74        let mut bracket_count = 0;
75        let mut in_string = false;
76        let mut escape_next = false;
77
78        let bytes = text.as_bytes();
79
80        for i in json_start..text.len() {
81            let char = bytes[i];
82
83            if escape_next {
84                escape_next = false;
85                continue;
86            }
87
88            if char == b'\\' {
89                escape_next = true;
90                continue;
91            }
92
93            if char == b'"' && !escape_next {
94                in_string = !in_string;
95                continue;
96            }
97
98            if !in_string {
99                if char == b'[' {
100                    bracket_count += 1;
101                } else if char == b']' {
102                    bracket_count -= 1;
103                    if bracket_count == 0 {
104                        // Found the matching closing bracket
105                        return Some((start_idx, &text[json_start..=i]));
106                    }
107                }
108            }
109        }
110
111        // Incomplete array (no matching closing bracket found)
112        None
113    }
114
115    /// Parse tool calls from a JSON array
116    fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
117        let value: Value = serde_json::from_str(json_str)
118            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
119
120        let mut tools = Vec::new();
121
122        if let Value::Array(arr) = value {
123            for item in arr.iter() {
124                if let Some(tool) = self.parse_single_object(item)? {
125                    tools.push(tool);
126                }
127            }
128        } else {
129            // Single object case (shouldn't happen with Mistral format, but handle it)
130            if let Some(tool) = self.parse_single_object(&value)? {
131                tools.push(tool);
132            }
133        }
134
135        Ok(tools)
136    }
137
138    /// Parse a single JSON object into a ToolCall
139    fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
140        let name = obj.get("name").and_then(|v| v.as_str());
141
142        if let Some(name) = name {
143            // Get arguments - Mistral uses "arguments" key
144            let empty_obj = Value::Object(serde_json::Map::new());
145            let args = obj.get("arguments").unwrap_or(&empty_obj);
146
147            // Convert arguments to JSON string
148            let arguments = serde_json::to_string(args)
149                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
150
151            Ok(Some(ToolCall {
152                function: FunctionCall {
153                    name: name.to_string(),
154                    arguments,
155                },
156            }))
157        } else {
158            Ok(None)
159        }
160    }
161}
162
163impl Default for MistralParser {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
169#[async_trait]
170impl ToolParser for MistralParser {
171    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
172        // Check if text contains Mistral format
173        if !self.has_tool_markers(text) {
174            return Ok((text.to_string(), vec![]));
175        }
176
177        // Extract JSON array from Mistral format with position
178        if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
179            // Extract normal text before BOT_TOKEN
180            let normal_text_before = if start_idx > 0 {
181                text[..start_idx].to_string()
182            } else {
183                String::new()
184            };
185
186            match self.parse_json_array(json_array) {
187                Ok(tools) => Ok((normal_text_before, tools)),
188                Err(e) => {
189                    // If JSON parsing fails, return the original text as normal text
190                    tracing::debug!("Failed to parse tool call: {}", e);
191                    Ok((text.to_string(), vec![]))
192                }
193            }
194        } else {
195            // Markers present but no complete array found
196            Ok((text.to_string(), vec![]))
197        }
198    }
199
200    async fn parse_incremental(
201        &mut self,
202        chunk: &str,
203        tools: &[Tool],
204    ) -> ParserResult<StreamingParseResult> {
205        // Append new text to buffer
206        self.buffer.push_str(chunk);
207        let current_text = &self.buffer.clone();
208
209        // Check if current_text has tool_call
210        let has_tool_start = self.has_tool_markers(current_text)
211            || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
212
213        if !has_tool_start {
214            // Only clear buffer if we're sure no tool call is starting
215            if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
216                let mut normal_text = self.buffer.clone();
217                self.buffer.clear();
218
219                // Strip ] only once (the closing bracket of [TOOL_CALLS] array)
220                // current_tool_id > 0 means we've parsed at least one tool
221                if !self.array_closed
222                    && self.current_tool_id > 0
223                    && normal_text.starts_with(self.eot_token)
224                {
225                    normal_text = normal_text
226                        .strip_prefix(self.eot_token)
227                        .unwrap()
228                        .to_string();
229                    self.array_closed = true;
230                }
231
232                return Ok(StreamingParseResult {
233                    normal_text,
234                    calls: vec![],
235                });
236            } else {
237                // Might be partial bot_token, keep buffering
238                return Ok(StreamingParseResult::default());
239            }
240        }
241
242        // Build tool indices
243        let tool_indices = helpers::get_tool_indices(tools);
244
245        // Determine start index for JSON parsing
246        let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
247            pos + self.bot_token.len()
248        } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
249            self.tool_call_separator.len()
250        } else {
251            0
252        };
253
254        helpers::handle_json_tool_streaming(
255            current_text,
256            start_idx,
257            &mut self.partial_json,
258            &tool_indices,
259            &mut self.buffer,
260            &mut self.current_tool_id,
261            &mut self.current_tool_name_sent,
262            &mut self.streamed_args_for_tool,
263            &mut self.prev_tool_call_arr,
264        )
265    }
266
267    fn has_tool_markers(&self, text: &str) -> bool {
268        text.contains("[TOOL_CALLS]")
269    }
270
271    fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::types::ToolCallItem>> {
272        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
273    }
274
275    fn reset(&mut self) {
276        helpers::reset_parser_state(
277            &mut self.buffer,
278            &mut self.prev_tool_call_arr,
279            &mut self.current_tool_id,
280            &mut self.current_tool_name_sent,
281            &mut self.streamed_args_for_tool,
282        );
283        self.array_closed = false;
284    }
285}