1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tracing::debug;
4
5use super::message::{AssistantContent, Message, MessageData, UserContent};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageGraph {
9 pub messages: Vec<Message>,
10 pub active_message_id: Option<String>,
11}
12
13impl Default for MessageGraph {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl MessageGraph {
20 pub fn new() -> Self {
21 Self {
22 messages: Vec::new(),
23 active_message_id: None,
24 }
25 }
26
27 pub fn add_message(&mut self, message: Message) {
28 self.active_message_id = Some(message.id().to_string());
29 self.messages.push(message);
30 }
31
32 pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
33 debug!(target: "message_graph::add_message", "Adding message: {:?}", message_data);
34 self.messages.push(Message {
35 data: message_data,
36 id: Message::generate_id("", Message::current_timestamp()),
37 timestamp: Message::current_timestamp(),
38 parent_message_id: self.active_message_id.clone(),
39 });
40 let last_index = self.messages.len().saturating_sub(1);
41 self.active_message_id = Some(self.messages[last_index].id().to_string());
42 &self.messages[last_index]
43 }
44
45 pub fn clear(&mut self) {
46 debug!(target:"message_graph::clear", "Clearing message graph");
47 self.messages.clear();
48 self.active_message_id = None;
49 }
50
51 pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
52 for message in &self.messages {
53 if let MessageData::Assistant { content, .. } = &message.data {
54 for content_block in content {
55 if let AssistantContent::ToolCall { tool_call, .. } = content_block
56 && tool_call.id == tool_id
57 {
58 return Some(tool_call.name.clone());
59 }
60 }
61 }
62 }
63 None
64 }
65
66 pub fn edit_message(
67 &mut self,
68 message_id: &str,
69 new_content: Vec<UserContent>,
70 ) -> Option<String> {
71 let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
72
73 if !matches!(&message_to_edit.data, MessageData::User { .. }) {
74 return None;
75 }
76
77 let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
78
79 let new_message_id = Message::generate_id("user", Message::current_timestamp());
80 let edited_message = Message {
81 data: MessageData::User {
82 content: new_content,
83 },
84 timestamp: Message::current_timestamp(),
85 id: new_message_id.clone(),
86 parent_message_id: parent_id,
87 };
88
89 self.messages.push(edited_message);
90 self.active_message_id = Some(new_message_id.clone());
91
92 Some(new_message_id)
93 }
94
95 pub fn update_command_execution(
96 &mut self,
97 message_id: &str,
98 command: String,
99 stdout: String,
100 stderr: String,
101 exit_code: i32,
102 ) -> Option<Message> {
103 for message in &mut self.messages {
104 if message.id() != message_id {
105 continue;
106 }
107
108 if let MessageData::User { content } = &mut message.data {
109 *content = vec![UserContent::CommandExecution {
110 command,
111 stdout,
112 stderr,
113 exit_code,
114 }];
115 return Some(message.clone());
116 }
117
118 return None;
119 }
120
121 None
122 }
123
124 pub fn replace_message(&mut self, updated: Message) -> bool {
125 for message in &mut self.messages {
126 if message.id() == updated.id() {
127 *message = updated;
128 return true;
129 }
130 }
131
132 self.messages.push(updated);
133 false
134 }
135
136 pub fn checkout(&mut self, message_id: &str) -> bool {
137 if self.messages.iter().any(|m| m.id() == message_id) {
138 self.active_message_id = Some(message_id.to_string());
139 true
140 } else {
141 false
142 }
143 }
144
145 pub fn get_active_thread(&self) -> Vec<&Message> {
146 if self.messages.is_empty() {
147 return Vec::new();
148 }
149
150 let head_id = if let Some(ref active_id) = self.active_message_id {
151 active_id.as_str()
152 } else {
153 self.messages.last().map_or("", |m| m.id())
154 };
155
156 let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
157 if current_msg.is_none() {
158 current_msg = self.messages.last();
159 }
160
161 let mut result = Vec::new();
162 let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
163
164 while let Some(msg) = current_msg {
165 result.push(msg);
166
167 current_msg = if let Some(parent_id) = msg.parent_message_id() {
168 id_map.get(parent_id).copied()
169 } else {
170 None
171 };
172 }
173
174 result.reverse();
175
176 debug!(
177 "Active thread: [{}]",
178 result
179 .iter()
180 .map(|msg| msg.id())
181 .collect::<Vec<_>>()
182 .join(", ")
183 );
184 result
185 }
186
187 pub fn get_thread_messages(&self) -> Vec<&Message> {
188 self.get_active_thread()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the content of a plain-text user message.
    fn user_text(text: &str) -> Vec<UserContent> {
        vec![UserContent::Text {
            text: text.to_string(),
        }]
    }

    /// Builds a user text message with an explicit id and parent link.
    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::User {
                content: user_text(content),
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    /// Builds an assistant text message with an explicit id and parent link.
    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    /// Ids of the messages returned by `get_thread_messages`, root first.
    fn thread_ids(graph: &MessageGraph) -> Vec<&str> {
        graph
            .get_thread_messages()
            .into_iter()
            .map(|m| m.id())
            .collect()
    }

    /// Editing the root message starts a new branch containing only the edit,
    /// while the old branch stays in storage.
    #[test]
    fn test_editing_message_in_the_middle_of_conversation() {
        let mut graph = MessageGraph::new();

        graph.add_message(create_user_message("msg1", None, "What is Rust?"));
        graph.add_message(create_assistant_message(
            "msg2",
            Some("msg1"),
            "A systems programming language.",
        ));
        graph.add_message(create_user_message("msg3", Some("msg2"), "Is it fast?"));
        graph.add_message(create_assistant_message(
            "msg4",
            Some("msg3"),
            "Yes, it is very fast.",
        ));

        let edited_id = graph
            .edit_message("msg1", user_text("What is Golang?"))
            .unwrap();

        assert_eq!(
            thread_ids(&graph),
            vec![edited_id.as_str()],
            "Active thread should only show the edited message"
        );

        for original_id in ["msg1", "msg2", "msg3", "msg4"] {
            assert!(
                graph.messages.iter().any(|m| m.id() == original_id),
                "{original_id} should still exist in message history"
            );
        }

        graph.add_message(create_assistant_message(
            "msg5",
            Some(&edited_id),
            "A systems programming language from Google.",
        ));

        assert_eq!(
            thread_ids(&graph),
            vec![edited_id.as_str(), "msg5"],
            "Should have the edited message and the new response."
        );
    }

    /// Editing the last user message keeps the shared history and swaps only
    /// the edited message into the active thread.
    #[test]
    fn test_get_thread_messages_after_edit() {
        let mut graph = MessageGraph::new();

        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "world"));
        graph.add_message(create_user_message(
            "msg3_original",
            Some("msg2"),
            "thanks",
        ));

        let edited_id = graph
            .edit_message("msg3_original", user_text("how are you"))
            .unwrap();

        graph.add_message(create_assistant_message(
            "msg4",
            Some(&edited_id),
            "I am fine",
        ));

        // Assert the exact root-to-head sequence, not just membership.
        assert_eq!(
            thread_ids(&graph),
            vec!["msg1", "msg2", edited_id.as_str(), "msg4"],
            "Should have 4 messages in the current thread, ordered root to head"
        );

        assert!(
            graph.messages.iter().any(|m| m.id() == "msg3_original"),
            "Original message should still exist in message history"
        );
    }

    /// Messages on an abandoned branch must not leak into the active thread.
    #[test]
    fn test_get_thread_messages_filters_other_branches() {
        let mut graph = MessageGraph::new();

        graph.add_message(create_user_message("msg1", None, "hi"));
        graph.add_message(create_assistant_message(
            "msg2",
            Some("msg1"),
            "Hello! How can I help?",
        ));
        graph.add_message(create_user_message(
            "msg3_original",
            Some("msg2"),
            "thanks",
        ));
        graph.add_message(create_assistant_message(
            "msg4_original",
            Some("msg3_original"),
            "You're welcome!",
        ));

        let edited_id = graph
            .edit_message("msg3_original", user_text("how are you"))
            .unwrap();

        graph.add_message(create_assistant_message(
            "msg4_new",
            Some(&edited_id),
            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
        ));
        graph.add_message(create_user_message(
            "msg5",
            Some("msg4_new"),
            "what messages have I sent you?",
        ));

        let thread_messages = graph.get_thread_messages();
        let user_messages: Vec<String> = thread_messages
            .iter()
            .filter(|m| matches!(m.data, MessageData::User { .. }))
            .map(|m| m.extract_text())
            .collect();

        assert_eq!(
            user_messages.len(),
            3,
            "Should have exactly 3 user messages"
        );
        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
        assert_eq!(
            user_messages[1], "how are you",
            "Second message should be 'how are you' (edited)"
        );
        assert_eq!(
            user_messages[2], "what messages have I sent you?",
            "Third message should be the question"
        );

        assert!(
            !user_messages.contains(&"thanks".to_string()),
            "Should NOT contain 'thanks' from the non-active branch"
        );

        assert!(
            graph.messages.iter().any(|m| m.id() == "msg3_original"),
            "Original 'thanks' message should still exist in message history"
        );
    }

    /// `checkout` switches the active head between branches and rejects
    /// unknown ids without disturbing the current head.
    #[test]
    fn test_checkout_branch() {
        let mut graph = MessageGraph::new();

        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "hi there"));

        let edited_id = graph.edit_message("msg1", user_text("goodbye")).unwrap();

        assert_eq!(graph.active_message_id, Some(edited_id.clone()));
        let thread = graph.get_active_thread();
        assert_eq!(thread.len(), 1);
        assert_eq!(thread[0].id(), edited_id);

        assert!(graph.checkout("msg2"));
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));

        let thread = graph.get_active_thread();
        assert_eq!(thread.len(), 2);
        assert_eq!(thread[0].id(), "msg1");
        assert_eq!(thread[1].id(), "msg2");

        assert!(!graph.checkout("non-existent"));
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));
    }

    /// `add_message` always moves the active head to the message just added,
    /// whichever branch that message belongs to.
    #[test]
    fn test_active_message_id_tracking() {
        let mut graph = MessageGraph::new();

        assert_eq!(graph.active_message_id, None);

        graph.add_message(create_user_message("msg1", None, "hello"));
        assert_eq!(graph.active_message_id, Some("msg1".to_string()));

        graph.add_message(create_assistant_message("msg2", Some("msg1"), "hi"));
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));

        graph.add_message(create_user_message(
            "msg3",
            Some("msg1"),
            "different question",
        ));
        assert_eq!(graph.active_message_id, Some("msg3".to_string()));

        graph.add_message(create_user_message("msg4", Some("msg3"), "follow up"));
        assert_eq!(graph.active_message_id, Some("msg4".to_string()));
    }
}