umf/chatml/mod.rs

1//! ChatML message formatter for simpaticoder.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use tiktoken_rs::cl100k_base;
7
8/// ChatML message roles.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum MessageRole {
12    System,
13    User,
14    Assistant,
15    Tool,
16}
17
18impl std::fmt::Display for MessageRole {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            MessageRole::System => write!(f, "system"),
22            MessageRole::User => write!(f, "user"),
23            MessageRole::Assistant => write!(f, "assistant"),
24            MessageRole::Tool => write!(f, "tool"),
25        }
26    }
27}
28
29/// Represents a single ChatML message.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ChatMLMessage {
32    pub role: MessageRole,
33    pub content: String,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub name: Option<String>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub tool_call_id: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub tool_calls: Option<Vec<crate::ToolCall>>,
40}
41
42impl ChatMLMessage {
43    /// Initialize ChatML message.
44    ///
45    /// # Arguments
46    /// * `role` - Message role (system, user, assistant).
47    /// * `content` - Message content.
48    /// * `name` - Optional name for the message sender.
49    pub fn new(role: MessageRole, content: String, name: Option<String>) -> Self {
50        Self {
51            role,
52            content,
53            name,
54            tool_call_id: None,
55            tool_calls: None,
56        }
57    }
58
59    /// Initialize ChatML tool message.
60    ///
61    /// # Arguments
62    /// * `content` - Tool result content.
63    /// * `tool_call_id` - ID of the tool call this message is responding to.
64    /// * `name` - Name of the tool that was called.
65    pub fn new_tool(content: String, tool_call_id: String, name: String) -> Self {
66        Self {
67            role: MessageRole::Tool,
68            content,
69            name: Some(name),
70            tool_call_id: Some(tool_call_id),
71            tool_calls: None,
72        }
73    }
74
75    /// Initialize ChatML assistant message with tool calls.
76    ///
77    /// # Arguments
78    /// * `content` - Assistant message content (can be empty for tool-only responses).
79    /// * `tool_calls` - Vector of tool calls made by the assistant.
80    pub fn new_assistant_with_tool_calls(
81        content: String,
82        tool_calls: Vec<crate::ToolCall>,
83    ) -> Self {
84        Self {
85            role: MessageRole::Assistant,
86            content,
87            name: None,
88            tool_call_id: None,
89            tool_calls: Some(tool_calls),
90        }
91    }
92
93    /// Convert message to dictionary format for OpenAI API.
94    pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
95        let mut message = HashMap::new();
96        message.insert(
97            "role".to_string(),
98            serde_json::Value::String(self.role.to_string()),
99        );
100        message.insert(
101            "content".to_string(),
102            serde_json::Value::String(self.content.clone()),
103        );
104
105        if let Some(name) = &self.name {
106            message.insert("name".to_string(), serde_json::Value::String(name.clone()));
107        }
108
109        if let Some(tool_call_id) = &self.tool_call_id {
110            message.insert(
111                "tool_call_id".to_string(),
112                serde_json::Value::String(tool_call_id.clone()),
113            );
114        }
115
116        if let Some(tool_calls) = &self.tool_calls {
117            let tool_calls_json = serde_json::to_value(tool_calls)
118                .unwrap_or_else(|_| serde_json::Value::Array(vec![]));
119            message.insert("tool_calls".to_string(), tool_calls_json);
120        }
121
122        message
123    }
124
125    /// Convert message to ChatML string format.
126    pub fn to_chatml_string(&self) -> String {
127        let name_part = if let Some(name) = &self.name {
128            format!(" name={}", name)
129        } else {
130            String::new()
131        };
132
133        format!(
134            "<|im_start|>{}{}\n{}\n<|im_end|>",
135            self.role, name_part, self.content
136        )
137    }
138}
139
140/// Formats messages in ChatML format for simpaticoder.
141#[derive(Debug, Clone)]
142pub struct ChatMLFormatter {
143    messages: Vec<ChatMLMessage>,
144}
145
146impl ChatMLFormatter {
147    /// Initialize ChatML formatter.
148    pub fn new() -> Self {
149        Self {
150            messages: Vec::new(),
151        }
152    }
153
154    /// Add system message.
155    ///
156    /// # Arguments
157    /// * `content` - System message content.
158    /// * `name` - Optional name for the system.
159    pub fn add_system_message(&mut self, content: String, name: Option<String>) -> &mut Self {
160        self.messages
161            .push(ChatMLMessage::new(MessageRole::System, content, name));
162        self
163    }
164
165    /// Add user message.
166    ///
167    /// # Arguments
168    /// * `content` - User message content.
169    /// * `name` - Optional name for the user.
170    pub fn add_user_message(&mut self, content: String, name: Option<String>) -> &mut Self {
171        self.messages
172            .push(ChatMLMessage::new(MessageRole::User, content, name));
173        self
174    }
175
176    /// Add assistant message.
177    ///
178    /// # Arguments
179    /// * `content` - Assistant message content.
180    /// * `name` - Optional name for the assistant.
181    pub fn add_assistant_message(&mut self, content: String, name: Option<String>) -> &mut Self {
182        self.messages
183            .push(ChatMLMessage::new(MessageRole::Assistant, content, name));
184        self
185    }
186
187    /// Add assistant message with tool calls.
188    ///
189    /// # Arguments
190    /// * `content` - Assistant message content (can be empty for tool-only responses).
191    /// * `tool_calls` - Vector of tool calls made by the assistant.
192    pub fn add_assistant_message_with_tool_calls(
193        &mut self,
194        content: String,
195        tool_calls: Vec<crate::ToolCall>,
196    ) -> &mut Self {
197        self.messages
198            .push(ChatMLMessage::new_assistant_with_tool_calls(
199                content, tool_calls,
200            ));
201        self
202    }
203
204    /// Add tool message.
205    ///
206    /// # Arguments
207    /// * `content` - Tool result content.
208    /// * `tool_call_id` - ID of the tool call this message is responding to.
209    /// * `name` - Name of the tool that was called.
210    pub fn add_tool_message(
211        &mut self,
212        content: String,
213        tool_call_id: String,
214        name: String,
215    ) -> &mut Self {
216        self.messages
217            .push(ChatMLMessage::new_tool(content, tool_call_id, name));
218        self
219    }
220
221    /// Add combined tool results message.
222    /// This is a temporary method for compatibility with current code structure.
223    ///
224    /// # Arguments
225    /// * `content` - Combined tool results content.
226    /// * `name` - Optional name for the tool results message.
227    pub fn add_tool_results_message(&mut self, content: String, name: Option<String>) -> &mut Self {
228        // For now, we'll use a generic tool_call_id for combined results
229        // This should be refactored to use individual tool messages in the future
230        self.messages.push(ChatMLMessage::new_tool(
231            content,
232            "combined_tool_results".to_string(),
233            name.unwrap_or_else(|| "tool_results".to_string()),
234        ));
235        self
236    }
237
238    /// Convert messages to OpenAI API format.
239    ///
240    /// # Returns
241    /// Vector of message HashMaps.
242    pub fn to_openai_format(&self) -> Vec<HashMap<String, serde_json::Value>> {
243        self.messages.iter().map(|msg| msg.to_dict()).collect()
244    }
245
246    /// Convert all messages to ChatML string format.
247    ///
248    /// # Returns
249    /// Full conversation in ChatML format.
250    pub fn to_chatml_string(&self) -> String {
251        self.messages
252            .iter()
253            .map(|msg| msg.to_chatml_string())
254            .collect::<Vec<_>>()
255            .join("\n")
256    }
257
258    /// Clear all messages.
259    pub fn clear(&mut self) -> &mut Self {
260        self.messages.clear();
261        self
262    }
263
264    /// Limit the number of messages to prevent context overflow.
265    ///
266    /// # Arguments
267    /// * `max_messages` - Maximum number of messages to keep.
268    pub fn limit_history(&mut self, max_messages: usize) -> &mut Self {
269        if self.messages.len() > max_messages {
270            // Keep the first message (system) and the most recent messages
271            let system_message = self.messages.first().cloned();
272            let recent_messages = self
273                .messages
274                .iter()
275                .rev()
276                .take(max_messages - 1)
277                .rev()
278                .cloned()
279                .collect::<Vec<_>>();
280
281            self.messages = if let Some(system) = system_message {
282                std::iter::once(system).chain(recent_messages).collect()
283            } else {
284                recent_messages
285            };
286        }
287        self
288    }
289
290    /// Get number of messages.
291    pub fn get_message_count(&self) -> usize {
292        self.messages.len()
293    }
294
295    /// Get the last message.
296    pub fn get_last_message(&self) -> Option<&ChatMLMessage> {
297        self.messages.last()
298    }
299
300    /// Get all messages.
301    pub fn get_messages(&self) -> &Vec<ChatMLMessage> {
302        &self.messages
303    }
304
305    /// Format a thought and command in the expected format.
306    ///
307    /// # Arguments
308    /// * `thought` - Brief reasoning explanation.
309    /// * `command` - Bash command to execute.
310    ///
311    /// # Returns
312    /// Formatted thought and command string.
313    pub fn format_thought_command(&self, thought: &str, command: &str) -> String {
314        format!("THOUGHT: {}\n\n```bash\n{}\n```", thought, command)
315    }
316
317    /// Replace template variables in a string with actual values.
318    ///
319    /// # Arguments
320    /// * `template` - Template string with {variable} placeholders.
321    /// * `variables` - HashMap of variable names to values.
322    ///
323    /// # Returns
324    /// String with variables replaced.
325    pub fn replace_template_variables(
326        &self,
327        template: &str,
328        variables: &HashMap<String, String>,
329    ) -> String {
330        let mut result = template.to_string();
331        for (key, value) in variables {
332            let placeholder = format!("{{{}}}", key);
333            result = result.replace(&placeholder, value);
334        }
335        result
336    }
337
338    /// Load and process a template file with variable replacement.
339    ///
340    /// # Arguments
341    /// * `template_path` - Path to the template file.
342    /// * `variables` - HashMap of variable names to values.
343    ///
344    /// # Returns
345    /// Processed template content or error.
346    pub fn process_template(
347        &self,
348        template_path: &str,
349        variables: &HashMap<String, String>,
350    ) -> Result<String, Box<dyn std::error::Error>> {
351        let template_content = std::fs::read_to_string(template_path)?;
352        Ok(self.replace_template_variables(&template_content, variables))
353    }
354
355    /// Validate that all messages have required fields.
356    ///
357    /// # Returns
358    /// True if all messages are valid, false otherwise.
359    pub fn validate_messages(&self) -> bool {
360        for message in &self.messages {
361            // Allow empty content for assistant messages with tool calls (OpenAI API requirement)
362            if message.content.is_empty() && message.tool_calls.is_none() {
363                return false;
364            }
365            // System messages should have names for simpaticoder
366            // Assistant messages should have names UNLESS they have tool_calls (OpenAI API pattern)
367            if message.role == MessageRole::System {
368                if message.name.is_none() {
369                    return false;
370                }
371            }
372            if message.role == MessageRole::Assistant {
373                // Assistant messages with tool_calls don't need names (per OpenAI API spec)
374                if message.tool_calls.is_none() && message.name.is_none() {
375                    return false;
376                }
377            }
378            // Tool messages must have tool_call_id and name
379            if matches!(message.role, MessageRole::Tool) {
380                if message.tool_call_id.is_none() || message.name.is_none() {
381                    return false;
382                }
383            }
384        }
385        true
386    }
387    /// Count the number of tokens in the current conversation.
388    ///
389    /// # Returns
390    /// Number of tokens, or 0 if tokenization fails.
391    pub fn count_tokens(&self) -> usize {
392        match cl100k_base() {
393            Ok(bpe) => {
394                let chatml_string = self.to_chatml_string();
395                let tokens = bpe.encode_with_special_tokens(&chatml_string);
396                tokens.len()
397            }
398            Err(_) => 0,
399        }
400    }
401}
402
403impl Default for ChatMLFormatter {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409#[cfg(test)]
410mod tests;