1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use tiktoken_rs::cl100k_base;
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum MessageRole {
12 System,
13 User,
14 Assistant,
15 Tool,
16}
17
18impl std::fmt::Display for MessageRole {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 MessageRole::System => write!(f, "system"),
22 MessageRole::User => write!(f, "user"),
23 MessageRole::Assistant => write!(f, "assistant"),
24 MessageRole::Tool => write!(f, "tool"),
25 }
26 }
27}
28
/// A single chat message, (de)serializable with serde.
///
/// The optional fields are omitted from serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMLMessage {
    /// Role of the message author.
    pub role: MessageRole,
    /// Message text. `ChatMLFormatter::validate_messages` accepts empty content only
    /// when `tool_calls` is set.
    pub content: String,
    /// Optional participant name; rendered as ` name=<name>` in ChatML output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool-call id; set by `ChatMLMessage::new_tool` for `Tool` messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls attached by `ChatMLMessage::new_assistant_with_tool_calls`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::ToolCall>>,
}
41
42impl ChatMLMessage {
43 pub fn new(role: MessageRole, content: String, name: Option<String>) -> Self {
50 Self {
51 role,
52 content,
53 name,
54 tool_call_id: None,
55 tool_calls: None,
56 }
57 }
58
59 pub fn new_tool(content: String, tool_call_id: String, name: String) -> Self {
66 Self {
67 role: MessageRole::Tool,
68 content,
69 name: Some(name),
70 tool_call_id: Some(tool_call_id),
71 tool_calls: None,
72 }
73 }
74
75 pub fn new_assistant_with_tool_calls(
81 content: String,
82 tool_calls: Vec<crate::ToolCall>,
83 ) -> Self {
84 Self {
85 role: MessageRole::Assistant,
86 content,
87 name: None,
88 tool_call_id: None,
89 tool_calls: Some(tool_calls),
90 }
91 }
92
93 pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
95 let mut message = HashMap::new();
96 message.insert(
97 "role".to_string(),
98 serde_json::Value::String(self.role.to_string()),
99 );
100 message.insert(
101 "content".to_string(),
102 serde_json::Value::String(self.content.clone()),
103 );
104
105 if let Some(name) = &self.name {
106 message.insert("name".to_string(), serde_json::Value::String(name.clone()));
107 }
108
109 if let Some(tool_call_id) = &self.tool_call_id {
110 message.insert(
111 "tool_call_id".to_string(),
112 serde_json::Value::String(tool_call_id.clone()),
113 );
114 }
115
116 if let Some(tool_calls) = &self.tool_calls {
117 let tool_calls_json = serde_json::to_value(tool_calls)
118 .unwrap_or_else(|_| serde_json::Value::Array(vec![]));
119 message.insert("tool_calls".to_string(), tool_calls_json);
120 }
121
122 message
123 }
124
125 pub fn to_chatml_string(&self) -> String {
127 let name_part = if let Some(name) = &self.name {
128 format!(" name={}", name)
129 } else {
130 String::new()
131 };
132
133 format!(
134 "<|im_start|>{}{}\n{}\n<|im_end|>",
135 self.role, name_part, self.content
136 )
137 }
138}
139
/// Ordered list of [`ChatMLMessage`]s with helpers to build, trim, validate and
/// render a conversation as ChatML text or OpenAI-style message maps.
#[derive(Debug, Clone)]
pub struct ChatMLFormatter {
    // Messages in insertion order (oldest first); new messages are pushed to the end.
    messages: Vec<ChatMLMessage>,
}
145
146impl ChatMLFormatter {
147 pub fn new() -> Self {
149 Self {
150 messages: Vec::new(),
151 }
152 }
153
154 pub fn add_system_message(&mut self, content: String, name: Option<String>) -> &mut Self {
160 self.messages
161 .push(ChatMLMessage::new(MessageRole::System, content, name));
162 self
163 }
164
165 pub fn add_user_message(&mut self, content: String, name: Option<String>) -> &mut Self {
171 self.messages
172 .push(ChatMLMessage::new(MessageRole::User, content, name));
173 self
174 }
175
176 pub fn add_assistant_message(&mut self, content: String, name: Option<String>) -> &mut Self {
182 self.messages
183 .push(ChatMLMessage::new(MessageRole::Assistant, content, name));
184 self
185 }
186
187 pub fn add_assistant_message_with_tool_calls(
193 &mut self,
194 content: String,
195 tool_calls: Vec<crate::ToolCall>,
196 ) -> &mut Self {
197 self.messages
198 .push(ChatMLMessage::new_assistant_with_tool_calls(
199 content, tool_calls,
200 ));
201 self
202 }
203
204 pub fn add_tool_message(
211 &mut self,
212 content: String,
213 tool_call_id: String,
214 name: String,
215 ) -> &mut Self {
216 self.messages
217 .push(ChatMLMessage::new_tool(content, tool_call_id, name));
218 self
219 }
220
221 pub fn add_tool_results_message(&mut self, content: String, name: Option<String>) -> &mut Self {
228 self.messages.push(ChatMLMessage::new_tool(
231 content,
232 "combined_tool_results".to_string(),
233 name.unwrap_or_else(|| "tool_results".to_string()),
234 ));
235 self
236 }
237
238 pub fn to_openai_format(&self) -> Vec<HashMap<String, serde_json::Value>> {
243 self.messages.iter().map(|msg| msg.to_dict()).collect()
244 }
245
246 pub fn to_chatml_string(&self) -> String {
251 self.messages
252 .iter()
253 .map(|msg| msg.to_chatml_string())
254 .collect::<Vec<_>>()
255 .join("\n")
256 }
257
258 pub fn clear(&mut self) -> &mut Self {
260 self.messages.clear();
261 self
262 }
263
264 pub fn limit_history(&mut self, max_messages: usize) -> &mut Self {
269 if self.messages.len() > max_messages {
270 let system_message = self.messages.first().cloned();
272 let recent_messages = self
273 .messages
274 .iter()
275 .rev()
276 .take(max_messages - 1)
277 .rev()
278 .cloned()
279 .collect::<Vec<_>>();
280
281 self.messages = if let Some(system) = system_message {
282 std::iter::once(system).chain(recent_messages).collect()
283 } else {
284 recent_messages
285 };
286 }
287 self
288 }
289
290 pub fn get_message_count(&self) -> usize {
292 self.messages.len()
293 }
294
295 pub fn get_last_message(&self) -> Option<&ChatMLMessage> {
297 self.messages.last()
298 }
299
300 pub fn get_messages(&self) -> &Vec<ChatMLMessage> {
302 &self.messages
303 }
304
305 pub fn format_thought_command(&self, thought: &str, command: &str) -> String {
314 format!("THOUGHT: {}\n\n```bash\n{}\n```", thought, command)
315 }
316
317 pub fn replace_template_variables(
326 &self,
327 template: &str,
328 variables: &HashMap<String, String>,
329 ) -> String {
330 let mut result = template.to_string();
331 for (key, value) in variables {
332 let placeholder = format!("{{{}}}", key);
333 result = result.replace(&placeholder, value);
334 }
335 result
336 }
337
338 pub fn process_template(
347 &self,
348 template_path: &str,
349 variables: &HashMap<String, String>,
350 ) -> Result<String, Box<dyn std::error::Error>> {
351 let template_content = std::fs::read_to_string(template_path)?;
352 Ok(self.replace_template_variables(&template_content, variables))
353 }
354
355 pub fn validate_messages(&self) -> bool {
360 for message in &self.messages {
361 if message.content.is_empty() && message.tool_calls.is_none() {
363 return false;
364 }
365 if message.role == MessageRole::System {
368 if message.name.is_none() {
369 return false;
370 }
371 }
372 if message.role == MessageRole::Assistant {
373 if message.tool_calls.is_none() && message.name.is_none() {
375 return false;
376 }
377 }
378 if matches!(message.role, MessageRole::Tool) {
380 if message.tool_call_id.is_none() || message.name.is_none() {
381 return false;
382 }
383 }
384 }
385 true
386 }
387 pub fn count_tokens(&self) -> usize {
392 match cl100k_base() {
393 Ok(bpe) => {
394 let chatml_string = self.to_chatml_string();
395 let tokens = bpe.encode_with_special_tokens(&chatml_string);
396 tokens.len()
397 }
398 Err(_) => 0,
399 }
400 }
401}
402
403impl Default for ChatMLFormatter {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
409#[cfg(test)]
410mod tests;