1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum MessageRole {
15 System,
16 User,
17 Assistant,
18 Tool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolCall {
24 pub id: String,
26 pub name: String,
28 pub arguments: String,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ConversationMessage {
35 pub role: MessageRole,
37 pub content: String,
39 #[serde(default, skip_serializing_if = "Vec::is_empty")]
41 pub tool_calls: Vec<ToolCall>,
42 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub tool_call_id: Option<String>,
45 #[serde(default, skip_serializing_if = "Option::is_none")]
47 pub tool_name: Option<String>,
48}
49
50impl ConversationMessage {
51 pub fn system(content: impl Into<String>) -> Self {
53 Self {
54 role: MessageRole::System,
55 content: content.into(),
56 tool_calls: Vec::new(),
57 tool_call_id: None,
58 tool_name: None,
59 }
60 }
61
62 pub fn user(content: impl Into<String>) -> Self {
64 Self {
65 role: MessageRole::User,
66 content: content.into(),
67 tool_calls: Vec::new(),
68 tool_call_id: None,
69 tool_name: None,
70 }
71 }
72
73 pub fn assistant(content: impl Into<String>) -> Self {
75 Self {
76 role: MessageRole::Assistant,
77 content: content.into(),
78 tool_calls: Vec::new(),
79 tool_call_id: None,
80 tool_name: None,
81 }
82 }
83
84 pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
86 Self {
87 role: MessageRole::Assistant,
88 content: String::new(),
89 tool_calls,
90 tool_call_id: None,
91 tool_name: None,
92 }
93 }
94
95 pub fn tool_result(
97 tool_call_id: impl Into<String>,
98 tool_name: impl Into<String>,
99 content: impl Into<String>,
100 ) -> Self {
101 Self {
102 role: MessageRole::Tool,
103 content: content.into(),
104 tool_calls: Vec::new(),
105 tool_call_id: Some(tool_call_id.into()),
106 tool_name: Some(tool_name.into()),
107 }
108 }
109
110 pub fn estimate_tokens(&self) -> usize {
117 let mut chars = self.content.len();
118 for tc in &self.tool_calls {
119 chars += tc.name.len() + tc.arguments.len() + tc.id.len() + 30; }
122 if let Some(ref id) = self.tool_call_id {
123 chars += id.len() + 20; }
125 (chars * 10 / 33).max(1) + 7
128 }
129}
130
131#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct Conversation {
134 messages: Vec<ConversationMessage>,
135}
136
137impl Conversation {
138 pub fn new() -> Self {
140 Self {
141 messages: Vec::new(),
142 }
143 }
144
145 pub fn with_system(system_prompt: impl Into<String>) -> Self {
147 Self {
148 messages: vec![ConversationMessage::system(system_prompt)],
149 }
150 }
151
152 pub fn push(&mut self, message: ConversationMessage) {
154 self.messages.push(message);
155 }
156
157 pub fn messages(&self) -> &[ConversationMessage] {
159 &self.messages
160 }
161
162 pub fn len(&self) -> usize {
164 self.messages.len()
165 }
166
167 pub fn is_empty(&self) -> bool {
169 self.messages.is_empty()
170 }
171
172 pub fn estimate_tokens(&self) -> usize {
174 self.messages.iter().map(|m| m.estimate_tokens()).sum()
175 }
176
177 pub fn system_message(&self) -> Option<&ConversationMessage> {
179 self.messages.iter().find(|m| m.role == MessageRole::System)
180 }
181
182 pub fn last_assistant_message(&self) -> Option<&ConversationMessage> {
184 self.messages
185 .iter()
186 .rev()
187 .find(|m| m.role == MessageRole::Assistant)
188 }
189
190 pub fn to_openai_messages(&self) -> Vec<serde_json::Value> {
195 self.messages
196 .iter()
197 .map(|msg| {
198 let mut obj = serde_json::Map::new();
199 let role_str = match msg.role {
200 MessageRole::System => "system",
201 MessageRole::User => "user",
202 MessageRole::Assistant => "assistant",
203 MessageRole::Tool => "tool",
204 };
205 obj.insert("role".into(), serde_json::Value::String(role_str.into()));
206
207 if !msg.content.is_empty() {
208 obj.insert(
209 "content".into(),
210 serde_json::Value::String(msg.content.clone()),
211 );
212 } else if msg.role != MessageRole::Assistant {
213 obj.insert("content".into(), serde_json::Value::String(String::new()));
215 }
216
217 if !msg.tool_calls.is_empty() {
218 let tool_calls: Vec<serde_json::Value> = msg
219 .tool_calls
220 .iter()
221 .map(|tc| {
222 serde_json::json!({
223 "id": tc.id,
224 "type": "function",
225 "function": {
226 "name": tc.name,
227 "arguments": tc.arguments,
228 }
229 })
230 })
231 .collect();
232 obj.insert("tool_calls".into(), serde_json::Value::Array(tool_calls));
233 }
234
235 if let Some(ref id) = msg.tool_call_id {
236 obj.insert("tool_call_id".into(), serde_json::Value::String(id.clone()));
237 }
238
239 serde_json::Value::Object(obj)
240 })
241 .collect()
242 }
243
244 pub fn to_anthropic_messages(&self) -> (Option<String>, Vec<serde_json::Value>) {
249 let system = self
250 .messages
251 .iter()
252 .find(|m| m.role == MessageRole::System)
253 .map(|m| m.content.clone());
254
255 let mut raw_messages: Vec<serde_json::Value> = Vec::new();
259
260 for msg in self
261 .messages
262 .iter()
263 .filter(|m| m.role != MessageRole::System)
264 {
265 let role_str = match msg.role {
266 MessageRole::User | MessageRole::Tool => "user",
267 MessageRole::Assistant => "assistant",
268 MessageRole::System => unreachable!(),
269 };
270
271 let serialized = if msg.role == MessageRole::Tool {
272 serde_json::json!({
274 "role": "user",
275 "content": [{
276 "type": "tool_result",
277 "tool_use_id": msg.tool_call_id.as_deref().unwrap_or(""),
278 "content": msg.content,
279 }]
280 })
281 } else if !msg.tool_calls.is_empty() {
282 let mut content_blocks: Vec<serde_json::Value> = Vec::new();
284 if !msg.content.is_empty() {
285 content_blocks.push(serde_json::json!({
286 "type": "text",
287 "text": msg.content,
288 }));
289 }
290 for tc in &msg.tool_calls {
291 let args: serde_json::Value =
292 serde_json::from_str(&tc.arguments).unwrap_or(serde_json::json!({}));
293 content_blocks.push(serde_json::json!({
294 "type": "tool_use",
295 "id": tc.id,
296 "name": tc.name,
297 "input": args,
298 }));
299 }
300 serde_json::json!({
301 "role": role_str,
302 "content": content_blocks,
303 })
304 } else {
305 serde_json::json!({
306 "role": role_str,
307 "content": msg.content,
308 })
309 };
310
311 if let Some(last) = raw_messages.last_mut() {
315 let last_role = last.get("role").and_then(|r| r.as_str()).unwrap_or("");
316 if last_role == role_str {
317 let prev_content = last.get_mut("content").unwrap();
319 let new_content = serialized.get("content").unwrap();
320
321 let prev_arr = if prev_content.is_array() {
323 prev_content.as_array_mut().unwrap()
324 } else {
325 let text = prev_content.as_str().unwrap_or("").to_string();
327 *prev_content = serde_json::json!([{"type": "text", "text": text}]);
328 prev_content.as_array_mut().unwrap()
329 };
330
331 if new_content.is_array() {
332 prev_arr.extend(new_content.as_array().unwrap().iter().cloned());
333 } else {
334 let text = new_content.as_str().unwrap_or("").to_string();
335 prev_arr.push(serde_json::json!({"type": "text", "text": text}));
336 }
337
338 continue;
339 }
340 }
341
342 raw_messages.push(serialized);
343 }
344
345 (system, raw_messages)
346 }
347
348 pub fn truncate_to_budget(&mut self, max_tokens: usize) {
353 if self.estimate_tokens() <= max_tokens {
354 return;
355 }
356
357 let system_msg = if self
358 .messages
359 .first()
360 .is_some_and(|m| m.role == MessageRole::System)
361 {
362 Some(self.messages[0].clone())
363 } else {
364 None
365 };
366
367 let system_tokens = system_msg.as_ref().map_or(0, |m| m.estimate_tokens());
368 let remaining_budget = max_tokens.saturating_sub(system_tokens);
369
370 let start_idx = if system_msg.is_some() { 1 } else { 0 };
372 let non_system: Vec<ConversationMessage> = self.messages.drain(start_idx..).rev().collect();
373
374 let mut kept = Vec::new();
375 let mut used_tokens = 0;
376 for msg in non_system {
377 let msg_tokens = msg.estimate_tokens();
378 if used_tokens + msg_tokens > remaining_budget {
379 break;
380 }
381 used_tokens += msg_tokens;
382 kept.push(msg);
383 }
384 kept.reverse();
385
386 self.messages.clear();
387 if let Some(sys) = system_msg {
388 self.messages.push(sys);
389 }
390 self.messages.extend(kept);
391 }
392
393 pub fn inject_knowledge_context(&mut self, context: impl Into<String>) {
396 let marker = "[KNOWLEDGE_CONTEXT]";
397 let content = format!("{}\n{}", marker, context.into());
398 let msg = ConversationMessage::system(content);
399
400 if let Some(pos) = self
402 .messages
403 .iter()
404 .position(|m| m.role == MessageRole::System && m.content.starts_with(marker))
405 {
406 self.messages[pos] = msg;
407 } else {
408 let insert_pos = if self
410 .messages
411 .first()
412 .is_some_and(|m| m.role == MessageRole::System)
413 {
414 1
415 } else {
416 0
417 };
418 self.messages.insert(insert_pos, msg);
419 }
420 }
421
422 pub fn metadata(&self) -> HashMap<String, String> {
424 let mut meta = HashMap::new();
425 meta.insert("message_count".into(), self.messages.len().to_string());
426 meta.insert(
427 "estimated_tokens".into(),
428 self.estimate_tokens().to_string(),
429 );
430 meta.insert(
431 "has_system".into(),
432 self.system_message().is_some().to_string(),
433 );
434 let tool_call_count: usize = self.messages.iter().map(|m| m.tool_calls.len()).sum();
435 meta.insert("tool_call_count".into(), tool_call_count.to_string());
436 let tool_result_count = self
437 .messages
438 .iter()
439 .filter(|m| m.role == MessageRole::Tool)
440 .count();
441 meta.insert("tool_result_count".into(), tool_result_count.to_string());
442 meta
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ToolCall` fixture.
    fn tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: arguments.into(),
        }
    }

    #[test]
    fn test_conversation_creation() {
        let conv = Conversation::with_system("You are a helpful assistant.");
        assert_eq!(conv.len(), 1);
        assert!(!conv.is_empty());
        assert_eq!(
            conv.system_message().map(|m| m.content.as_str()),
            Some("You are a helpful assistant.")
        );
        assert!(Conversation::new().is_empty());
    }

    #[test]
    fn test_message_constructors() {
        let sys = ConversationMessage::system("system");
        assert_eq!(sys.role, MessageRole::System);
        assert_eq!(sys.content, "system");
        assert!(sys.tool_calls.is_empty());
        assert!(sys.tool_call_id.is_none());
        assert!(sys.tool_name.is_none());

        let user = ConversationMessage::user("hello");
        assert_eq!(user.role, MessageRole::User);
        assert_eq!(user.content, "hello");

        let asst = ConversationMessage::assistant("hi there");
        assert_eq!(asst.role, MessageRole::Assistant);
        assert_eq!(asst.content, "hi there");

        let calls = ConversationMessage::assistant_tool_calls(vec![tool_call("c1", "t1", "{}")]);
        assert_eq!(calls.role, MessageRole::Assistant);
        assert!(calls.content.is_empty());
        assert_eq!(calls.tool_calls.len(), 1);

        let tool = ConversationMessage::tool_result("call_1", "search", "results here");
        assert_eq!(tool.role, MessageRole::Tool);
        assert_eq!(tool.content, "results here");
        assert_eq!(tool.tool_call_id.as_deref(), Some("call_1"));
        assert_eq!(tool.tool_name.as_deref(), Some("search"));
    }

    #[test]
    fn test_openai_serialization() {
        let mut conv = Conversation::with_system("You are a test agent.");
        conv.push(ConversationMessage::user("Search for rust crates"));
        conv.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "call_1",
            "web_search",
            r#"{"query":"rust crates"}"#,
        )]));
        conv.push(ConversationMessage::tool_result(
            "call_1",
            "web_search",
            "Found: serde, tokio, reqwest",
        ));
        conv.push(ConversationMessage::assistant(
            "I found serde, tokio, and reqwest.",
        ));

        let openai_msgs = conv.to_openai_messages();
        assert_eq!(openai_msgs.len(), 5);

        assert_eq!(openai_msgs[0]["role"], "system");
        assert_eq!(openai_msgs[0]["content"], "You are a test agent.");

        assert_eq!(openai_msgs[1]["role"], "user");
        assert_eq!(openai_msgs[1]["content"], "Search for rust crates");

        assert_eq!(openai_msgs[2]["role"], "assistant");
        // A tool-call-only assistant turn carries no `content` key.
        assert!(openai_msgs[2].get("content").is_none());
        let tool_calls = openai_msgs[2]["tool_calls"]
            .as_array()
            .expect("tool_calls should be an array");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["id"], "call_1");
        assert_eq!(tool_calls[0]["type"], "function");
        assert_eq!(tool_calls[0]["function"]["name"], "web_search");
        assert_eq!(
            tool_calls[0]["function"]["arguments"],
            r#"{"query":"rust crates"}"#
        );

        assert_eq!(openai_msgs[3]["role"], "tool");
        assert_eq!(openai_msgs[3]["tool_call_id"], "call_1");
        assert_eq!(openai_msgs[3]["content"], "Found: serde, tokio, reqwest");

        assert_eq!(openai_msgs[4]["role"], "assistant");
        assert_eq!(openai_msgs[4]["content"], "I found serde, tokio, and reqwest.");
        assert!(openai_msgs[4].get("tool_calls").is_none());
    }

    #[test]
    fn test_anthropic_serialization() {
        let mut conv = Conversation::with_system("System prompt here.");
        conv.push(ConversationMessage::user("Hello"));
        conv.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "tu_1",
            "calculator",
            r#"{"expr":"2+2"}"#,
        )]));
        conv.push(ConversationMessage::tool_result("tu_1", "calculator", "4"));
        conv.push(ConversationMessage::assistant("The result is 4."));

        let (system, messages) = conv.to_anthropic_messages();
        assert_eq!(system.as_deref(), Some("System prompt here."));
        assert_eq!(messages.len(), 4);

        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"], "Hello");

        assert_eq!(messages[1]["role"], "assistant");
        let content = messages[1]["content"].as_array().unwrap();
        // No text block is emitted for the empty assistant text.
        assert_eq!(content.len(), 1);
        assert_eq!(content[0]["type"], "tool_use");
        assert_eq!(content[0]["id"], "tu_1");
        assert_eq!(content[0]["name"], "calculator");
        assert_eq!(content[0]["input"]["expr"], "2+2");

        assert_eq!(messages[2]["role"], "user");
        let result_content = messages[2]["content"].as_array().unwrap();
        assert_eq!(result_content.len(), 1);
        assert_eq!(result_content[0]["type"], "tool_result");
        assert_eq!(result_content[0]["tool_use_id"], "tu_1");
        assert_eq!(result_content[0]["content"], "4");

        assert_eq!(messages[3]["role"], "assistant");
        assert_eq!(messages[3]["content"], "The result is 4.");
    }

    #[test]
    fn test_anthropic_merges_consecutive_tool_results() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("go"));
        conv.push(ConversationMessage::assistant_tool_calls(vec![
            tool_call("c1", "t1", "{}"),
            tool_call("c2", "t2", "{}"),
        ]));
        conv.push(ConversationMessage::tool_result("c1", "t1", "one"));
        conv.push(ConversationMessage::tool_result("c2", "t2", "two"));

        let (_, messages) = conv.to_anthropic_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[1]["content"].as_array().unwrap().len(), 2);

        assert_eq!(messages[2]["role"], "user");
        let blocks = messages[2]["content"].as_array().unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0]["tool_use_id"], "c1");
        assert_eq!(blocks[0]["content"], "one");
        assert_eq!(blocks[1]["tool_use_id"], "c2");
        assert_eq!(blocks[1]["content"], "two");
    }

    #[test]
    fn test_anthropic_merges_consecutive_text() {
        let mut conv = Conversation::new();
        conv.push(ConversationMessage::user("first"));
        conv.push(ConversationMessage::user("second"));

        let (system, messages) = conv.to_anthropic_messages();
        assert!(system.is_none());
        assert_eq!(messages.len(), 1);
        let blocks = messages[0]["content"].as_array().unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0]["type"], "text");
        assert_eq!(blocks[0]["text"], "first");
        assert_eq!(blocks[1]["text"], "second");
    }

    #[test]
    fn test_token_estimation() {
        // 13 bytes -> 13 * 10 / 33 = 3, plus the fixed 7.
        let msg = ConversationMessage::user("Hello, world!");
        assert_eq!(msg.estimate_tokens(), 10);

        // Tool call: (2 + 2 + 2 + 30) = 36 bytes -> 10, plus 7.
        let calls = ConversationMessage::assistant_tool_calls(vec![tool_call("c1", "t1", "{}")]);
        assert_eq!(calls.estimate_tokens(), 17);

        // Tool result: 2 content bytes + (2 + 20) id bytes = 24 -> 7, plus 7.
        let result = ConversationMessage::tool_result("c1", "t1", "ok");
        assert_eq!(result.estimate_tokens(), 14);
    }

    #[test]
    fn test_conversation_token_estimation() {
        let mut conv = Conversation::with_system("Be helpful.");
        conv.push(ConversationMessage::user("Hi"));
        conv.push(ConversationMessage::assistant("Hello!"));

        let per_message: usize = conv.messages().iter().map(|m| m.estimate_tokens()).sum();
        // 10 ("Be helpful.") + 8 ("Hi") + 8 ("Hello!").
        assert_eq!(conv.estimate_tokens(), 26);
        assert_eq!(conv.estimate_tokens(), per_message);
    }

    #[test]
    fn test_truncate_to_budget() {
        let mut conv = Conversation::with_system("sys");
        for i in 0..20 {
            conv.push(ConversationMessage::user(format!(
                "Message number {} with some extra text to take up tokens",
                i
            )));
            conv.push(ConversationMessage::assistant(format!("Reply {}", i)));
        }

        let original_len = conv.len();
        assert!(original_len > 10);

        conv.truncate_to_budget(100);
        assert!(conv.len() < original_len);
        assert_eq!(conv.messages()[0].role, MessageRole::System);
        // System (8) + the newest five messages (9 + 23 + 9 + 23 + 9) fit in 100.
        assert_eq!(conv.len(), 6);
        assert_eq!(conv.messages()[1].content, "Reply 17");
        assert_eq!(conv.messages().last().unwrap().content, "Reply 19");
        assert_eq!(conv.estimate_tokens(), 81);
        assert!(conv.estimate_tokens() <= 100);
    }

    #[test]
    fn test_truncate_within_budget_is_noop() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("hi"));
        let budget = conv.estimate_tokens();

        conv.truncate_to_budget(budget);
        assert_eq!(conv.len(), 2);
    }

    #[test]
    fn test_truncate_without_system_keeps_newest() {
        let mut conv = Conversation::new();
        for text in ["a", "b", "c"] {
            conv.push(ConversationMessage::user(text));
        }

        // Each one-byte message is estimated at 8 tokens.
        conv.truncate_to_budget(16);
        let kept: Vec<&str> = conv.messages().iter().map(|m| m.content.as_str()).collect();
        assert_eq!(kept, ["b", "c"]);
    }

    #[test]
    fn test_metadata() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("hi"));
        conv.push(ConversationMessage::assistant_tool_calls(vec![
            tool_call("c1", "t1", "{}"),
            tool_call("c2", "t2", "{}"),
        ]));
        conv.push(ConversationMessage::tool_result("c1", "t1", "ok"));
        conv.push(ConversationMessage::tool_result("c2", "t2", "ok"));

        let meta = conv.metadata();
        assert_eq!(meta.len(), 5);
        assert_eq!(meta["message_count"], "5");
        assert_eq!(meta["estimated_tokens"], "72");
        assert_eq!(meta["has_system"], "true");
        assert_eq!(meta["tool_call_count"], "2");
        assert_eq!(meta["tool_result_count"], "2");
    }

    #[test]
    fn test_last_assistant_message() {
        let mut conv = Conversation::new();
        assert!(conv.last_assistant_message().is_none());

        conv.push(ConversationMessage::user("hi"));
        conv.push(ConversationMessage::assistant("first"));
        conv.push(ConversationMessage::user("more"));
        conv.push(ConversationMessage::assistant("second"));

        assert_eq!(conv.last_assistant_message().unwrap().content, "second");
    }

    #[test]
    fn test_inject_knowledge_context_after_system() {
        let mut conv = Conversation::with_system("You are helpful.");
        conv.push(ConversationMessage::user("hello"));
        conv.inject_knowledge_context("Some knowledge here");

        assert_eq!(conv.len(), 3);
        assert_eq!(conv.messages()[0].role, MessageRole::System);
        assert_eq!(conv.messages()[0].content, "You are helpful.");
        assert_eq!(conv.messages()[1].role, MessageRole::System);
        assert_eq!(
            conv.messages()[1].content,
            "[KNOWLEDGE_CONTEXT]\nSome knowledge here"
        );
        assert_eq!(conv.messages()[2].role, MessageRole::User);
    }

    #[test]
    fn test_inject_knowledge_context_replaces_existing() {
        let mut conv = Conversation::with_system("System prompt");
        conv.inject_knowledge_context("First knowledge");
        conv.inject_knowledge_context("Updated knowledge");

        let knowledge_msgs: Vec<_> = conv
            .messages()
            .iter()
            .filter(|m| m.content.contains("[KNOWLEDGE_CONTEXT]"))
            .collect();
        assert_eq!(knowledge_msgs.len(), 1);
        assert!(knowledge_msgs[0].content.contains("Updated knowledge"));
        assert_eq!(conv.len(), 2);
    }

    #[test]
    fn test_inject_knowledge_context_no_system_message() {
        let mut conv = Conversation::new();
        conv.push(ConversationMessage::user("hello"));
        conv.inject_knowledge_context("Knowledge without system");

        assert_eq!(conv.len(), 2);
        assert_eq!(conv.messages()[0].role, MessageRole::System);
        assert!(conv.messages()[0].content.starts_with("[KNOWLEDGE_CONTEXT]"));
        assert_eq!(conv.messages()[1].role, MessageRole::User);
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut conv = Conversation::with_system("test");
        conv.push(ConversationMessage::user("hello"));
        conv.push(ConversationMessage::assistant_tool_calls(vec![tool_call(
            "tc1",
            "search",
            r#"{"q":"test"}"#,
        )]));
        conv.push(ConversationMessage::tool_result("tc1", "search", "done"));

        let json = serde_json::to_string(&conv).unwrap();
        let restored: Conversation = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.len(), conv.len());
        assert_eq!(restored.messages()[1].role, MessageRole::User);
        assert_eq!(restored.messages()[2].tool_calls[0].name, "search");
        assert_eq!(restored.messages()[2].tool_calls[0].arguments, r#"{"q":"test"}"#);
        assert_eq!(restored.messages()[3].role, MessageRole::Tool);
        assert_eq!(restored.messages()[3].tool_call_id.as_deref(), Some("tc1"));
        assert_eq!(restored.messages()[3].tool_name.as_deref(), Some("search"));
    }

    #[test]
    fn test_serde_skips_empty_tool_fields() {
        let value = serde_json::to_value(ConversationMessage::user("hi")).unwrap();
        assert_eq!(value, serde_json::json!({ "role": "user", "content": "hi" }));
    }
}