Skip to main content

steer_core/app/conversation/
graph.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet};
3use tracing::debug;
4
5use super::message::{AssistantContent, Message, MessageData, UserContent};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageGraph {
9    pub messages: Vec<Message>,
10    pub active_message_id: Option<String>,
11    #[serde(default)]
12    pub compaction_summary_ids: HashSet<String>,
13}
14
15impl Default for MessageGraph {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl MessageGraph {
22    pub fn new() -> Self {
23        Self {
24            messages: Vec::new(),
25            active_message_id: None,
26            compaction_summary_ids: HashSet::new(),
27        }
28    }
29
30    pub fn add_message(&mut self, message: Message) {
31        self.active_message_id = Some(message.id().to_string());
32        self.messages.push(message);
33    }
34
35    pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
36        debug!(target: "message_graph::add_message", "Adding message: {:?}", message_data);
37        self.messages.push(Message {
38            data: message_data,
39            id: Message::generate_id("", Message::current_timestamp()),
40            timestamp: Message::current_timestamp(),
41            parent_message_id: self.active_message_id.clone(),
42        });
43        let last_index = self.messages.len().saturating_sub(1);
44        self.active_message_id = Some(self.messages[last_index].id().to_string());
45        &self.messages[last_index]
46    }
47
48    pub fn clear(&mut self) {
49        debug!(target:"message_graph::clear", "Clearing message graph");
50        self.messages.clear();
51        self.active_message_id = None;
52    }
53
54    pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
55        for message in &self.messages {
56            if let MessageData::Assistant { content, .. } = &message.data {
57                for content_block in content {
58                    if let AssistantContent::ToolCall { tool_call, .. } = content_block
59                        && tool_call.id == tool_id
60                    {
61                        return Some(tool_call.name.clone());
62                    }
63                }
64            }
65        }
66        None
67    }
68
69    pub fn edit_message(
70        &mut self,
71        message_id: &str,
72        new_content: Vec<UserContent>,
73    ) -> Option<String> {
74        let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
75
76        if !matches!(&message_to_edit.data, MessageData::User { .. }) {
77            return None;
78        }
79
80        let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
81
82        let new_message_id = Message::generate_id("user", Message::current_timestamp());
83        let edited_message = Message {
84            data: MessageData::User {
85                content: new_content,
86            },
87            timestamp: Message::current_timestamp(),
88            id: new_message_id.clone(),
89            parent_message_id: parent_id,
90        };
91
92        self.messages.push(edited_message);
93        self.active_message_id = Some(new_message_id.clone());
94
95        Some(new_message_id)
96    }
97
98    pub fn update_command_execution(
99        &mut self,
100        message_id: &str,
101        command: String,
102        stdout: String,
103        stderr: String,
104        exit_code: i32,
105    ) -> Option<Message> {
106        for message in &mut self.messages {
107            if message.id() != message_id {
108                continue;
109            }
110
111            if let MessageData::User { content } = &mut message.data {
112                *content = vec![UserContent::CommandExecution {
113                    command,
114                    stdout,
115                    stderr,
116                    exit_code,
117                }];
118                return Some(message.clone());
119            }
120
121            return None;
122        }
123
124        None
125    }
126
127    pub fn replace_message(&mut self, updated: Message) -> bool {
128        for message in &mut self.messages {
129            if message.id() == updated.id() {
130                *message = updated;
131                return true;
132            }
133        }
134
135        self.messages.push(updated);
136        false
137    }
138
139    pub fn checkout(&mut self, message_id: &str) -> bool {
140        if self.messages.iter().any(|m| m.id() == message_id) {
141            self.active_message_id = Some(message_id.to_string());
142            true
143        } else {
144            false
145        }
146    }
147
148    pub fn mark_compaction_summary(&mut self, id: String) {
149        self.compaction_summary_ids.insert(id);
150    }
151
152    pub fn get_active_thread(&self) -> Vec<&Message> {
153        if self.messages.is_empty() {
154            return Vec::new();
155        }
156
157        let head_id = if let Some(ref active_id) = self.active_message_id {
158            active_id.as_str()
159        } else {
160            self.messages.last().map_or("", |m| m.id())
161        };
162
163        let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
164        if current_msg.is_none() {
165            current_msg = self.messages.last();
166        }
167
168        let mut result = Vec::new();
169        let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
170
171        while let Some(msg) = current_msg {
172            result.push(msg);
173
174            // Stop walking at compaction boundary — the summary message is
175            // included but older (pre-compaction) messages are not.
176            if self.compaction_summary_ids.contains(msg.id()) {
177                break;
178            }
179
180            current_msg = if let Some(parent_id) = msg.parent_message_id() {
181                id_map.get(parent_id).copied()
182            } else {
183                None
184            };
185        }
186
187        result.reverse();
188
189        debug!(
190            "Active thread: [{}]",
191            result
192                .iter()
193                .map(|msg| msg.id())
194                .collect::<Vec<_>>()
195                .join(", ")
196        );
197        result
198    }
199
200    pub fn get_thread_messages(&self) -> Vec<&Message> {
201        self.get_active_thread()
202    }
203}
204
205#[cfg(test)]
206mod tests {
207    use super::*;
208
209    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
210        Message {
211            data: MessageData::User {
212                content: vec![UserContent::Text {
213                    text: content.to_string(),
214                }],
215            },
216            timestamp: Message::current_timestamp(),
217            id: id.to_string(),
218            parent_message_id: parent_id.map(String::from),
219        }
220    }
221
222    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
223        Message {
224            data: MessageData::Assistant {
225                content: vec![AssistantContent::Text {
226                    text: content.to_string(),
227                }],
228            },
229            timestamp: Message::current_timestamp(),
230            id: id.to_string(),
231            parent_message_id: parent_id.map(String::from),
232        }
233    }
234
235    #[test]
236    fn test_editing_message_in_the_middle_of_conversation() {
237        let mut graph = MessageGraph::new();
238
239        let msg1 = create_user_message("msg1", None, "What is Rust?");
240        graph.add_message(msg1.clone());
241
242        let msg2 =
243            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
244        graph.add_message(msg2.clone());
245
246        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
247        graph.add_message(msg3.clone());
248
249        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
250        graph.add_message(msg4.clone());
251
252        let edited_id = graph
253            .edit_message(
254                "msg1",
255                vec![UserContent::Text {
256                    text: "What is Golang?".to_string(),
257                }],
258            )
259            .unwrap();
260
261        let messages_after_edit = graph.get_thread_messages();
262        let message_ids_after_edit: Vec<&str> =
263            messages_after_edit.iter().map(|m| m.id()).collect();
264
265        assert_eq!(
266            message_ids_after_edit.len(),
267            1,
268            "Active thread should only show the edited message"
269        );
270        assert_eq!(message_ids_after_edit[0], edited_id.as_str());
271
272        assert!(graph.messages.iter().any(|m| m.id() == "msg1"));
273        assert!(graph.messages.iter().any(|m| m.id() == "msg2"));
274        assert!(graph.messages.iter().any(|m| m.id() == "msg3"));
275        assert!(graph.messages.iter().any(|m| m.id() == "msg4"));
276
277        let msg5 = create_assistant_message(
278            "msg5",
279            Some(&edited_id),
280            "A systems programming language from Google.",
281        );
282        graph.add_message(msg5.clone());
283
284        let final_messages = graph.get_thread_messages();
285        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();
286
287        assert_eq!(
288            final_messages.len(),
289            2,
290            "Should have the edited message and the new response."
291        );
292        assert_eq!(final_message_ids[0], edited_id.as_str());
293        assert_eq!(final_message_ids[1], "msg5");
294    }
295
296    #[test]
297    fn test_get_thread_messages_after_edit() {
298        let mut graph = MessageGraph::new();
299
300        let msg1 = create_user_message("msg1", None, "hello");
301        graph.add_message(msg1.clone());
302
303        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
304        graph.add_message(msg2.clone());
305
306        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
307        graph.add_message(msg3_original.clone());
308
309        let edited_id = graph
310            .edit_message(
311                "msg3_original",
312                vec![UserContent::Text {
313                    text: "how are you".to_string(),
314                }],
315            )
316            .unwrap();
317
318        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
319        graph.add_message(msg4.clone());
320
321        let thread_messages = graph.get_thread_messages();
322
323        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();
324
325        assert_eq!(
326            thread_message_ids.len(),
327            4,
328            "Should have 4 messages in the current thread"
329        );
330        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
331        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
332        assert!(
333            thread_message_ids.contains(&edited_id.as_str()),
334            "Should contain the edited message"
335        );
336        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");
337
338        assert!(
339            graph.messages.iter().any(|m| m.id() == "msg3_original"),
340            "Original message should still exist in message history"
341        );
342    }
343
344    #[test]
345    fn test_get_thread_messages_filters_other_branches() {
346        let mut graph = MessageGraph::new();
347
348        let msg1 = create_user_message("msg1", None, "hi");
349        graph.add_message(msg1.clone());
350
351        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
352        graph.add_message(msg2.clone());
353
354        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
355        graph.add_message(msg3_original.clone());
356
357        let msg4_original =
358            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
359        graph.add_message(msg4_original.clone());
360
361        let edited_id = graph
362            .edit_message(
363                "msg3_original",
364                vec![UserContent::Text {
365                    text: "how are you".to_string(),
366                }],
367            )
368            .unwrap();
369
370        let msg4_new = create_assistant_message(
371            "msg4_new",
372            Some(&edited_id),
373            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
374        );
375        graph.add_message(msg4_new.clone());
376
377        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
378        graph.add_message(msg5.clone());
379
380        let thread_messages = graph.get_thread_messages();
381
382        let user_messages: Vec<String> = thread_messages
383            .iter()
384            .filter(|m| matches!(m.data, MessageData::User { .. }))
385            .map(|m| m.extract_text())
386            .collect();
387
388        println!("User messages seen: {user_messages:?}");
389
390        assert_eq!(
391            user_messages.len(),
392            3,
393            "Should have exactly 3 user messages"
394        );
395        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
396        assert_eq!(
397            user_messages[1], "how are you",
398            "Second message should be 'how are you' (edited)"
399        );
400        assert_eq!(
401            user_messages[2], "what messages have I sent you?",
402            "Third message should be the question"
403        );
404
405        assert!(
406            !user_messages.contains(&"thanks".to_string()),
407            "Should NOT contain 'thanks' from the non-active branch"
408        );
409
410        assert!(
411            graph.messages.iter().any(|m| m.id() == "msg3_original"),
412            "Original 'thanks' message should still exist in message history"
413        );
414    }
415
416    #[test]
417    fn test_checkout_branch() {
418        let mut graph = MessageGraph::new();
419
420        let msg1 = create_user_message("msg1", None, "hello");
421        graph.add_message(msg1.clone());
422
423        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
424        graph.add_message(msg2.clone());
425
426        let edited_id = graph
427            .edit_message(
428                "msg1",
429                vec![UserContent::Text {
430                    text: "goodbye".to_string(),
431                }],
432            )
433            .unwrap();
434
435        assert_eq!(graph.active_message_id, Some(edited_id.clone()));
436        let thread = graph.get_active_thread();
437        assert_eq!(thread.len(), 1);
438        assert_eq!(thread[0].id(), edited_id);
439
440        assert!(graph.checkout("msg2"));
441        assert_eq!(graph.active_message_id, Some("msg2".to_string()));
442
443        let thread = graph.get_active_thread();
444        assert_eq!(thread.len(), 2);
445        assert_eq!(thread[0].id(), "msg1");
446        assert_eq!(thread[1].id(), "msg2");
447
448        assert!(!graph.checkout("non-existent"));
449        assert_eq!(graph.active_message_id, Some("msg2".to_string()));
450    }
451
452    #[test]
453    fn test_compaction_boundary_filters_old_messages() {
454        let mut graph = MessageGraph::new();
455
456        // Pre-compaction messages
457        let msg1 = create_user_message("msg1", None, "hello");
458        graph.add_message(msg1);
459        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
460        graph.add_message(msg2);
461        let msg3 = create_user_message("msg3", Some("msg2"), "tell me more");
462        graph.add_message(msg3);
463
464        // Compaction summary (parent points to compacted head)
465        let summary = create_assistant_message("summary", Some("msg3"), "Summary of conversation.");
466        graph.add_message(summary);
467        graph.mark_compaction_summary("summary".to_string());
468
469        // Post-compaction messages
470        let msg4 = create_user_message("msg4", Some("summary"), "new question");
471        graph.add_message(msg4);
472        let msg5 = create_assistant_message("msg5", Some("msg4"), "new answer");
473        graph.add_message(msg5);
474
475        let thread = graph.get_thread_messages();
476        let ids: Vec<&str> = thread.iter().map(|m| m.id()).collect();
477
478        // Should include summary + post-compaction, but NOT pre-compaction
479        assert_eq!(ids, vec!["summary", "msg4", "msg5"]);
480    }
481
482    #[test]
483    fn test_active_message_id_tracking() {
484        let mut graph = MessageGraph::new();
485
486        assert_eq!(graph.active_message_id, None);
487
488        let msg1 = create_user_message("msg1", None, "hello");
489        graph.add_message(msg1);
490        assert_eq!(graph.active_message_id, Some("msg1".to_string()));
491
492        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
493        graph.add_message(msg2);
494        assert_eq!(graph.active_message_id, Some("msg2".to_string()));
495
496        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
497        graph.add_message(msg3);
498        assert_eq!(graph.active_message_id, Some("msg3".to_string()));
499
500        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
501        graph.add_message(msg4);
502        assert_eq!(graph.active_message_id, Some("msg4".to_string()));
503    }
504}