// tool_parser/parsers/cohere.rs

//! Cohere Command model tool call parser
//!
//! Parses tool calls from `<|START_ACTION|>...<|END_ACTION|>` blocks.
//! Supports both CMD3 and CMD4 formats.
//!
//! # Format
//!
//! Cohere models output tool calls in the following format:
//! ```text
//! <|START_RESPONSE|>Let me help with that.<|END_RESPONSE|>
//! <|START_ACTION|>
//! {"tool_name": "search", "parameters": {"query": "rust programming"}}
//! <|END_ACTION|>
//! ```
//!
//! Or, for multiple tool calls, as a JSON array inside a single action block:
//! ```text
//! <|START_ACTION|>
//! [
//!   {"tool_name": "search", "parameters": {"query": "rust"}},
//!   {"tool_name": "get_weather", "parameters": {"city": "Paris"}}
//! ]
//! <|END_ACTION|>
//! ```
//!
//! # Field Mapping
//! - `tool_name` → `name` (a plain `name` key is accepted as a fallback)
//! - `parameters` → `arguments` (a plain `arguments` key is accepted as a fallback)
29
30use async_trait::async_trait;
31use openai_protocol::common::Tool;
32use serde_json::Value;
33
34use crate::{
35    errors::{ParserError, ParserResult},
36    parsers::helpers,
37    partial_json::PartialJson,
38    traits::ToolParser,
39    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
40};
41
/// Opens a tool-call (action) block.
const START_ACTION: &str = "<|START_ACTION|>";
/// Closes a tool-call (action) block.
const END_ACTION: &str = "<|END_ACTION|>";
/// Opens the model's user-facing response; stripped from emitted text.
const START_RESPONSE: &str = "<|START_RESPONSE|>";
/// Closes the model's user-facing response; stripped from emitted text.
const END_RESPONSE: &str = "<|END_RESPONSE|>";
/// Opens a plain text segment; stripped from emitted text.
const START_TEXT: &str = "<|START_TEXT|>";
/// Closes a plain text segment; stripped from emitted text.
const END_TEXT: &str = "<|END_TEXT|>";
48
/// Streaming state of the Cohere tool-call parser.
///
/// The parser alternates between the two variants: it stays in `Text` until
/// a `<|START_ACTION|>` marker is seen, then in `InAction` until the matching
/// `<|END_ACTION|>` marker arrives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// Outside any action block; scanning for the `START_ACTION` marker.
    Text,
    /// Inside an action block; accumulating the JSON payload.
    InAction,
}
57
58/// Cohere Command model tool call parser
59///
60/// Handles the Cohere-specific format:
61/// `<|START_ACTION|>{"tool_name": "func", "parameters": {...}}<|END_ACTION|>`
62pub struct CohereParser {
63    /// Current parsing state
64    state: ParseState,
65
66    /// Parser for handling incomplete JSON during streaming
67    partial_json: PartialJson,
68
69    /// Buffer for accumulating incomplete patterns across chunks
70    buffer: String,
71
72    /// Stores complete tool call info (name and arguments) for each tool being parsed
73    prev_tool_call_arr: Vec<Value>,
74
75    /// Index of currently streaming tool call (-1 means no active tool)
76    current_tool_id: i32,
77
78    /// Flag for whether current tool's name has been sent to client
79    current_tool_name_sent: bool,
80
81    /// Tracks raw JSON string content streamed to client for each tool's arguments
82    streamed_args_for_tool: Vec<String>,
83}
84
85impl CohereParser {
86    /// Create a new Cohere parser
87    pub fn new() -> Self {
88        Self {
89            state: ParseState::Text,
90            partial_json: PartialJson::default(),
91            buffer: String::new(),
92            prev_tool_call_arr: Vec::new(),
93            current_tool_id: -1,
94            current_tool_name_sent: false,
95            streamed_args_for_tool: Vec::new(),
96        }
97    }
98
99    /// Clean text by removing response markers
100    fn clean_text(text: &str) -> String {
101        text.replace(START_RESPONSE, "")
102            .replace(END_RESPONSE, "")
103            .replace(START_TEXT, "")
104            .replace(END_TEXT, "")
105    }
106
107    /// Convert a Cohere tool call JSON object to our ToolCall format
108    fn convert_tool_call(json_str: &str) -> ParserResult<Vec<ToolCall>> {
109        let value: Value = serde_json::from_str(json_str.trim())
110            .map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {e}")))?;
111
112        let tools = match value {
113            Value::Array(arr) => arr,
114            single => vec![single],
115        };
116
117        tools
118            .into_iter()
119            .filter_map(|tool| {
120                // Cohere uses "tool_name" instead of "name"
121                let name = tool
122                    .get("tool_name")
123                    .and_then(|v| v.as_str())
124                    .or_else(|| tool.get("name").and_then(|v| v.as_str()))?;
125
126                // Cohere uses "parameters" instead of "arguments"
127                let parameters = tool
128                    .get("parameters")
129                    .or_else(|| tool.get("arguments"))
130                    .map(|v| v.to_string())
131                    .unwrap_or_else(|| "{}".to_string());
132
133                Some(Ok(ToolCall {
134                    function: FunctionCall {
135                        name: name.to_string(),
136                        arguments: parameters,
137                    },
138                }))
139            })
140            .collect()
141    }
142
143    /// Extract JSON content between START_ACTION and END_ACTION
144    fn extract_action_json(text: &str) -> Option<(usize, &str, usize)> {
145        let start_idx = text.find(START_ACTION)?;
146        let json_start = start_idx + START_ACTION.len();
147
148        if let Some(end_offset) = text[json_start..].find(END_ACTION) {
149            let json_str = &text[json_start..json_start + end_offset];
150            Some((
151                start_idx,
152                json_str,
153                json_start + end_offset + END_ACTION.len(),
154            ))
155        } else {
156            // Incomplete - no END_ACTION yet
157            None
158        }
159    }
160}
161
162impl Default for CohereParser {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168#[async_trait]
169impl ToolParser for CohereParser {
170    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
171        // Check if text contains Cohere format
172        if !self.has_tool_markers(text) {
173            let cleaned = Self::clean_text(text);
174            return Ok((cleaned.trim().to_string(), vec![]));
175        }
176
177        let mut normal_text = String::new();
178        let mut tool_calls = Vec::new();
179        let mut remaining = text;
180
181        while let Some((start_idx, json_str, end_idx)) = Self::extract_action_json(remaining) {
182            // Text before action
183            normal_text.push_str(&remaining[..start_idx]);
184
185            // Parse tool calls from this action block
186            match Self::convert_tool_call(json_str) {
187                Ok(calls) => tool_calls.extend(calls),
188                Err(e) => {
189                    tracing::debug!("Failed to parse Cohere tool call: {}", e);
190                }
191            }
192
193            remaining = &remaining[end_idx..];
194        }
195
196        // Append any remaining text after last action block
197        normal_text.push_str(remaining);
198
199        // Clean up response markers
200        let cleaned_text = Self::clean_text(&normal_text);
201
202        Ok((cleaned_text.trim().to_string(), tool_calls))
203    }
204
205    async fn parse_incremental(
206        &mut self,
207        chunk: &str,
208        tools: &[Tool],
209    ) -> ParserResult<StreamingParseResult> {
210        self.buffer.push_str(chunk);
211
212        match self.state {
213            ParseState::Text => {
214                // Check for START_ACTION marker
215                let start_pos = self.buffer.find(START_ACTION);
216                if let Some(pos) = start_pos {
217                    // Emit text before the action as normal text
218                    let text_before = Self::clean_text(&self.buffer[..pos]);
219
220                    // Switch to InAction state and keep only content after START_ACTION
221                    self.state = ParseState::InAction;
222                    self.buffer.drain(..pos + START_ACTION.len());
223
224                    return Ok(StreamingParseResult {
225                        normal_text: text_before,
226                        calls: vec![],
227                    });
228                }
229
230                // Check for partial START_ACTION
231                if helpers::ends_with_partial_token(&self.buffer, START_ACTION).is_some() {
232                    // Keep buffering
233                    return Ok(StreamingParseResult::default());
234                }
235
236                // No action starting, emit cleaned text
237                let cleaned = Self::clean_text(&self.buffer);
238                self.buffer.clear();
239                Ok(StreamingParseResult {
240                    normal_text: cleaned,
241                    calls: vec![],
242                })
243            }
244
245            ParseState::InAction => {
246                // Check if we have END_ACTION
247                if let Some(pos) = self.buffer.find(END_ACTION) {
248                    // We have complete JSON - extract it before modifying buffer
249                    let json_content = self.buffer[..pos].to_string();
250
251                    // Build tool indices
252                    let tool_indices = helpers::get_tool_indices(tools);
253
254                    // Create a temporary buffer for the helper (it expects to manage buffer state)
255                    let mut temp_buffer = String::new();
256
257                    // Use helper for streaming - pass JSON directly with offset 0
258                    let result = helpers::handle_json_tool_streaming(
259                        &json_content,
260                        0,
261                        &mut self.partial_json,
262                        &tool_indices,
263                        &mut temp_buffer,
264                        &mut self.current_tool_id,
265                        &mut self.current_tool_name_sent,
266                        &mut self.streamed_args_for_tool,
267                        &mut self.prev_tool_call_arr,
268                    )?;
269
270                    // Move past END_ACTION and switch back to Text state
271                    self.buffer.drain(..pos + END_ACTION.len());
272                    self.state = ParseState::Text;
273
274                    return Ok(result);
275                }
276
277                // Partial JSON - buffer and wait for END_ACTION
278                // Unlike formats without end markers, we can't stream partial JSON safely
279                // because we don't know if the JSON is complete until we see END_ACTION
280                Ok(StreamingParseResult::default())
281            }
282        }
283    }
284
285    fn has_tool_markers(&self, text: &str) -> bool {
286        text.contains(START_ACTION) || text.contains(END_ACTION)
287    }
288
289    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
290        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
291    }
292
293    fn reset(&mut self) {
294        self.state = ParseState::Text;
295        helpers::reset_parser_state(
296            &mut self.buffer,
297            &mut self.prev_tool_call_arr,
298            &mut self.current_tool_id,
299            &mut self.current_tool_name_sent,
300            &mut self.streamed_args_for_tool,
301        );
302    }
303}
304
/// Unit tests for the non-streaming (`parse_complete`) path and marker detection.
#[cfg(test)]
mod tests {
    use super::*;

    /// A response followed by a single-object action block.
    #[tokio::test]
    async fn test_single_tool_call() {
        let parser = CohereParser::new();
        let input = r#"<|START_RESPONSE|>Let me search for that.<|END_RESPONSE|>
<|START_ACTION|>
{"tool_name": "search", "parameters": {"query": "rust programming"}}
<|END_ACTION|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(normal_text, "Let me search for that.");
        assert_eq!(tools[0].function.name, "search");

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(args["query"], "rust programming");
    }

    /// An action block carrying a JSON array yields one call per element, in order.
    #[tokio::test]
    async fn test_multiple_tool_calls_array() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>
[
  {"tool_name": "search", "parameters": {"query": "rust"}},
  {"tool_name": "get_weather", "parameters": {"city": "Paris"}}
]
<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[1].function.name, "get_weather");
    }

    /// Several separate action blocks are all parsed; text between them is kept.
    #[tokio::test]
    async fn test_multiple_action_blocks() {
        let parser = CohereParser::new();
        let input = concat!(
            r#"<|START_ACTION|>{"tool_name": "first", "parameters": {}}<|END_ACTION|>"#,
            "middle",
            r#"<|START_ACTION|>{"tool_name": "second", "parameters": {"x": 1}}<|END_ACTION|>"#,
        );

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(normal_text, "middle");
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "first");
        assert_eq!(tools[1].function.name, "second");

        let args: Value = serde_json::from_str(&tools[1].function.arguments).unwrap();
        assert_eq!(args["x"], 1);
    }

    /// Output without action markers only has its response markers stripped.
    #[tokio::test]
    async fn test_no_tool_calls() {
        let parser = CohereParser::new();
        let input = "<|START_RESPONSE|>Hello, how can I help?<|END_RESPONSE|>";

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0);
        assert_eq!(normal_text, "Hello, how can I help?");
    }

    /// Marker detection is synchronous, so no async runtime is needed here.
    #[test]
    fn test_has_tool_markers() {
        let parser = CohereParser::new();

        assert!(parser.has_tool_markers("<|START_ACTION|>"));
        assert!(parser.has_tool_markers("<|END_ACTION|>"));
        assert!(parser.has_tool_markers("Some text <|START_ACTION|> more"));
        assert!(!parser.has_tool_markers("Just plain text"));
        assert!(!parser.has_tool_markers("[TOOL_CALLS]")); // Mistral format
    }

    /// A call without `parameters` gets an empty JSON object as arguments.
    #[tokio::test]
    async fn test_empty_parameters() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"tool_name": "ping"}<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "ping");
        assert_eq!(tools[0].function.arguments, "{}");
    }

    /// The `name` / `arguments` keys are accepted in place of the Cohere keys.
    #[tokio::test]
    async fn test_name_and_arguments_fallback_keys() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"name": "lookup", "arguments": {"id": 7}}<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "lookup");

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(args["id"], 7);
    }

    /// Array entries without a tool name are skipped; the rest are kept.
    #[tokio::test]
    async fn test_entries_without_tool_name_are_skipped() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>[{"parameters": {}}, {"tool_name": "ok"}]<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "ok");
        assert_eq!(tools[0].function.arguments, "{}");
    }

    /// Deeply nested argument values survive the round trip.
    #[tokio::test]
    async fn test_nested_json() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>
{"tool_name": "process", "parameters": {"config": {"nested": {"value": [1, 2, 3]}}}}
<|END_ACTION|>"#;

        let (_, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);

        let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(
            args["config"]["nested"]["value"],
            serde_json::json!([1, 2, 3])
        );
    }

    /// Text markers on both sides of an action block are removed from the text.
    #[tokio::test]
    async fn test_text_markers_cleaned() {
        let parser = CohereParser::new();
        let input = r#"<|START_TEXT|>Some intro<|END_TEXT|>
<|START_ACTION|>{"tool_name": "test", "parameters": {}}<|END_ACTION|>
<|START_TEXT|>Conclusion<|END_TEXT|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(normal_text.contains("Some intro"));
        assert!(normal_text.contains("Conclusion"));
        assert!(!normal_text.contains("<|START_TEXT|>"));
        assert!(!normal_text.contains("<|END_TEXT|>"));
    }

    /// Malformed JSON in an action block is dropped without failing the parse
    /// and without leaking the block into the returned text.
    #[tokio::test]
    async fn test_malformed_json() {
        let parser = CohereParser::new();
        let input = r#"<|START_ACTION|>{"tool_name": invalid}<|END_ACTION|>"#;

        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 0);
        assert!(normal_text.is_empty());
    }
}