1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet};
3use tracing::debug;
4
5use super::message::{AssistantContent, Message, MessageData, UserContent};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageGraph {
9 pub messages: Vec<Message>,
10 pub active_message_id: Option<String>,
11 #[serde(default)]
12 pub compaction_summary_ids: HashSet<String>,
13}
14
15impl Default for MessageGraph {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl MessageGraph {
22 pub fn new() -> Self {
23 Self {
24 messages: Vec::new(),
25 active_message_id: None,
26 compaction_summary_ids: HashSet::new(),
27 }
28 }
29
30 pub fn add_message(&mut self, message: Message) {
31 self.active_message_id = Some(message.id().to_string());
32 self.messages.push(message);
33 }
34
35 pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
36 debug!(target: "message_graph::add_message", "Adding message: {:?}", message_data);
37 self.messages.push(Message {
38 data: message_data,
39 id: Message::generate_id("", Message::current_timestamp()),
40 timestamp: Message::current_timestamp(),
41 parent_message_id: self.active_message_id.clone(),
42 });
43 let last_index = self.messages.len().saturating_sub(1);
44 self.active_message_id = Some(self.messages[last_index].id().to_string());
45 &self.messages[last_index]
46 }
47
48 pub fn clear(&mut self) {
49 debug!(target:"message_graph::clear", "Clearing message graph");
50 self.messages.clear();
51 self.active_message_id = None;
52 }
53
54 pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
55 for message in &self.messages {
56 if let MessageData::Assistant { content, .. } = &message.data {
57 for content_block in content {
58 if let AssistantContent::ToolCall { tool_call, .. } = content_block
59 && tool_call.id == tool_id
60 {
61 return Some(tool_call.name.clone());
62 }
63 }
64 }
65 }
66 None
67 }
68
69 pub fn edit_message(
70 &mut self,
71 message_id: &str,
72 new_content: Vec<UserContent>,
73 ) -> Option<String> {
74 let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
75
76 if !matches!(&message_to_edit.data, MessageData::User { .. }) {
77 return None;
78 }
79
80 let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
81
82 let new_message_id = Message::generate_id("user", Message::current_timestamp());
83 let edited_message = Message {
84 data: MessageData::User {
85 content: new_content,
86 },
87 timestamp: Message::current_timestamp(),
88 id: new_message_id.clone(),
89 parent_message_id: parent_id,
90 };
91
92 self.messages.push(edited_message);
93 self.active_message_id = Some(new_message_id.clone());
94
95 Some(new_message_id)
96 }
97
98 pub fn update_command_execution(
99 &mut self,
100 message_id: &str,
101 command: String,
102 stdout: String,
103 stderr: String,
104 exit_code: i32,
105 ) -> Option<Message> {
106 for message in &mut self.messages {
107 if message.id() != message_id {
108 continue;
109 }
110
111 if let MessageData::User { content } = &mut message.data {
112 *content = vec![UserContent::CommandExecution {
113 command,
114 stdout,
115 stderr,
116 exit_code,
117 }];
118 return Some(message.clone());
119 }
120
121 return None;
122 }
123
124 None
125 }
126
127 pub fn replace_message(&mut self, updated: Message) -> bool {
128 for message in &mut self.messages {
129 if message.id() == updated.id() {
130 *message = updated;
131 return true;
132 }
133 }
134
135 self.messages.push(updated);
136 false
137 }
138
139 pub fn checkout(&mut self, message_id: &str) -> bool {
140 if self.messages.iter().any(|m| m.id() == message_id) {
141 self.active_message_id = Some(message_id.to_string());
142 true
143 } else {
144 false
145 }
146 }
147
148 pub fn mark_compaction_summary(&mut self, id: String) {
149 self.compaction_summary_ids.insert(id);
150 }
151
152 pub fn get_active_thread(&self) -> Vec<&Message> {
153 if self.messages.is_empty() {
154 return Vec::new();
155 }
156
157 let head_id = if let Some(ref active_id) = self.active_message_id {
158 active_id.as_str()
159 } else {
160 self.messages.last().map_or("", |m| m.id())
161 };
162
163 let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
164 if current_msg.is_none() {
165 current_msg = self.messages.last();
166 }
167
168 let mut result = Vec::new();
169 let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
170
171 while let Some(msg) = current_msg {
172 result.push(msg);
173
174 if self.compaction_summary_ids.contains(msg.id()) {
177 break;
178 }
179
180 current_msg = if let Some(parent_id) = msg.parent_message_id() {
181 id_map.get(parent_id).copied()
182 } else {
183 None
184 };
185 }
186
187 result.reverse();
188
189 debug!(
190 "Active thread: [{}]",
191 result
192 .iter()
193 .map(|msg| msg.id())
194 .collect::<Vec<_>>()
195 .join(", ")
196 );
197 result
198 }
199
200 pub fn get_thread_messages(&self) -> Vec<&Message> {
201 self.get_active_thread()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `data` in a message with a fixed id and an optional parent id.
    fn build_message(id: &str, parent_id: Option<&str>, data: MessageData) -> Message {
        Message {
            data,
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    /// User message holding a single text block.
    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        let block = UserContent::Text {
            text: content.to_string(),
        };
        build_message(id, parent_id, MessageData::User {
            content: vec![block],
        })
    }

    /// Assistant message holding a single text block.
    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        let block = AssistantContent::Text {
            text: content.to_string(),
        };
        build_message(id, parent_id, MessageData::Assistant {
            content: vec![block],
        })
    }

    /// Edits `message_id` to a single text block and returns the new id,
    /// panicking if the edit is rejected.
    fn edit_text(graph: &mut MessageGraph, message_id: &str, text: &str) -> String {
        graph
            .edit_message(
                message_id,
                vec![UserContent::Text {
                    text: text.to_string(),
                }],
            )
            .unwrap()
    }

    /// Ids along the current thread, root first.
    fn thread_ids(graph: &MessageGraph) -> Vec<&str> {
        graph
            .get_thread_messages()
            .into_iter()
            .map(|m| m.id())
            .collect()
    }

    /// Whether any stored message, on any branch, has `id`.
    fn stored(graph: &MessageGraph, id: &str) -> bool {
        graph.messages.iter().any(|m| m.id() == id)
    }

    #[test]
    fn test_editing_message_in_the_middle_of_conversation() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "What is Rust?"));
        graph.add_message(create_assistant_message(
            "msg2",
            Some("msg1"),
            "A systems programming language.",
        ));
        graph.add_message(create_user_message("msg3", Some("msg2"), "Is it fast?"));
        graph.add_message(create_assistant_message(
            "msg4",
            Some("msg3"),
            "Yes, it is very fast.",
        ));

        let edited_id = edit_text(&mut graph, "msg1", "What is Golang?");

        let ids = thread_ids(&graph);
        assert_eq!(
            ids.len(),
            1,
            "Active thread should only show the edited message"
        );
        assert_eq!(ids[0], edited_id.as_str());

        // The superseded branch is hidden from the thread, not deleted.
        for id in ["msg1", "msg2", "msg3", "msg4"] {
            assert!(stored(&graph, id));
        }

        graph.add_message(create_assistant_message(
            "msg5",
            Some(&edited_id),
            "A systems programming language from Google.",
        ));

        let ids = thread_ids(&graph);
        assert_eq!(
            ids.len(),
            2,
            "Should have the edited message and the new response."
        );
        assert_eq!(ids[0], edited_id.as_str());
        assert_eq!(ids[1], "msg5");
    }

    #[test]
    fn test_get_thread_messages_after_edit() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "world"));
        graph.add_message(create_user_message("msg3_original", Some("msg2"), "thanks"));

        let edited_id = edit_text(&mut graph, "msg3_original", "how are you");
        graph.add_message(create_assistant_message(
            "msg4",
            Some(&edited_id),
            "I am fine",
        ));

        let ids = thread_ids(&graph);
        assert_eq!(ids.len(), 4, "Should have 4 messages in the current thread");
        assert!(ids.contains(&"msg1"), "Should contain msg1");
        assert!(ids.contains(&"msg2"), "Should contain msg2");
        assert!(
            ids.contains(&edited_id.as_str()),
            "Should contain the edited message"
        );
        assert!(ids.contains(&"msg4"), "Should contain msg4");

        assert!(
            stored(&graph, "msg3_original"),
            "Original message should still exist in message history"
        );
    }

    #[test]
    fn test_get_thread_messages_filters_other_branches() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hi"));
        graph.add_message(create_assistant_message(
            "msg2",
            Some("msg1"),
            "Hello! How can I help?",
        ));
        graph.add_message(create_user_message("msg3_original", Some("msg2"), "thanks"));
        graph.add_message(create_assistant_message(
            "msg4_original",
            Some("msg3_original"),
            "You're welcome!",
        ));

        let edited_id = edit_text(&mut graph, "msg3_original", "how are you");
        graph.add_message(create_assistant_message(
            "msg4_new",
            Some(&edited_id),
            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
        ));
        graph.add_message(create_user_message(
            "msg5",
            Some("msg4_new"),
            "what messages have I sent you?",
        ));

        let mut user_messages: Vec<String> = Vec::new();
        for message in graph.get_thread_messages() {
            if matches!(message.data, MessageData::User { .. }) {
                user_messages.push(message.extract_text());
            }
        }

        println!("User messages seen: {user_messages:?}");

        assert_eq!(
            user_messages.len(),
            3,
            "Should have exactly 3 user messages"
        );
        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
        assert_eq!(
            user_messages[1], "how are you",
            "Second message should be 'how are you' (edited)"
        );
        assert_eq!(
            user_messages[2], "what messages have I sent you?",
            "Third message should be the question"
        );

        assert!(
            !user_messages.contains(&"thanks".to_string()),
            "Should NOT contain 'thanks' from the non-active branch"
        );

        assert!(
            stored(&graph, "msg3_original"),
            "Original 'thanks' message should still exist in message history"
        );
    }

    #[test]
    fn test_checkout_branch() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "hi there"));

        let edited_id = edit_text(&mut graph, "msg1", "goodbye");

        assert_eq!(graph.active_message_id.as_deref(), Some(edited_id.as_str()));
        let thread = graph.get_active_thread();
        assert_eq!(thread.len(), 1);
        assert_eq!(thread[0].id(), edited_id);

        assert!(graph.checkout("msg2"));
        assert_eq!(graph.active_message_id.as_deref(), Some("msg2"));

        let ids: Vec<&str> = graph
            .get_active_thread()
            .into_iter()
            .map(|m| m.id())
            .collect();
        assert_eq!(ids, ["msg1", "msg2"]);

        // A failed checkout must not move the active pointer.
        assert!(!graph.checkout("non-existent"));
        assert_eq!(graph.active_message_id.as_deref(), Some("msg2"));
    }

    #[test]
    fn test_compaction_boundary_filters_old_messages() {
        let mut graph = MessageGraph::new();

        let history = [
            create_user_message("msg1", None, "hello"),
            create_assistant_message("msg2", Some("msg1"), "hi there"),
            create_user_message("msg3", Some("msg2"), "tell me more"),
            create_assistant_message("summary", Some("msg3"), "Summary of conversation."),
        ];
        for message in history {
            graph.add_message(message);
        }
        graph.mark_compaction_summary("summary".to_string());

        graph.add_message(create_user_message("msg4", Some("summary"), "new question"));
        graph.add_message(create_assistant_message("msg5", Some("msg4"), "new answer"));

        assert_eq!(thread_ids(&graph), vec!["summary", "msg4", "msg5"]);
    }

    #[test]
    fn test_active_message_id_tracking() {
        let mut graph = MessageGraph::new();
        assert_eq!(graph.active_message_id, None);

        // Each added message becomes active, regardless of which parent it names.
        let steps = [
            (create_user_message("msg1", None, "hello"), "msg1"),
            (create_assistant_message("msg2", Some("msg1"), "hi"), "msg2"),
            (
                create_user_message("msg3", Some("msg1"), "different question"),
                "msg3",
            ),
            (create_user_message("msg4", Some("msg3"), "follow up"), "msg4"),
        ];
        for (message, expected_active) in steps {
            graph.add_message(message);
            assert_eq!(graph.active_message_id.as_deref(), Some(expected_active));
        }
    }
}