Skip to main content

steer_core/app/conversation/
graph.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tracing::debug;
4
5use super::message::{AssistantContent, Message, MessageData, UserContent};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MessageGraph {
9    pub messages: Vec<Message>,
10    pub active_message_id: Option<String>,
11}
12
13impl Default for MessageGraph {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl MessageGraph {
20    pub fn new() -> Self {
21        Self {
22            messages: Vec::new(),
23            active_message_id: None,
24        }
25    }
26
27    pub fn add_message(&mut self, message: Message) {
28        self.active_message_id = Some(message.id().to_string());
29        self.messages.push(message);
30    }
31
32    pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
33        debug!(target: "message_graph::add_message", "Adding message: {:?}", message_data);
34        self.messages.push(Message {
35            data: message_data,
36            id: Message::generate_id("", Message::current_timestamp()),
37            timestamp: Message::current_timestamp(),
38            parent_message_id: self.active_message_id.clone(),
39        });
40        let last_index = self.messages.len().saturating_sub(1);
41        self.active_message_id = Some(self.messages[last_index].id().to_string());
42        &self.messages[last_index]
43    }
44
45    pub fn clear(&mut self) {
46        debug!(target:"message_graph::clear", "Clearing message graph");
47        self.messages.clear();
48        self.active_message_id = None;
49    }
50
51    pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
52        for message in &self.messages {
53            if let MessageData::Assistant { content, .. } = &message.data {
54                for content_block in content {
55                    if let AssistantContent::ToolCall { tool_call, .. } = content_block
56                        && tool_call.id == tool_id
57                    {
58                        return Some(tool_call.name.clone());
59                    }
60                }
61            }
62        }
63        None
64    }
65
66    pub fn edit_message(
67        &mut self,
68        message_id: &str,
69        new_content: Vec<UserContent>,
70    ) -> Option<String> {
71        let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
72
73        if !matches!(&message_to_edit.data, MessageData::User { .. }) {
74            return None;
75        }
76
77        let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
78
79        let new_message_id = Message::generate_id("user", Message::current_timestamp());
80        let edited_message = Message {
81            data: MessageData::User {
82                content: new_content,
83            },
84            timestamp: Message::current_timestamp(),
85            id: new_message_id.clone(),
86            parent_message_id: parent_id,
87        };
88
89        self.messages.push(edited_message);
90        self.active_message_id = Some(new_message_id.clone());
91
92        Some(new_message_id)
93    }
94
95    pub fn update_command_execution(
96        &mut self,
97        message_id: &str,
98        command: String,
99        stdout: String,
100        stderr: String,
101        exit_code: i32,
102    ) -> Option<Message> {
103        for message in &mut self.messages {
104            if message.id() != message_id {
105                continue;
106            }
107
108            if let MessageData::User { content } = &mut message.data {
109                *content = vec![UserContent::CommandExecution {
110                    command,
111                    stdout,
112                    stderr,
113                    exit_code,
114                }];
115                return Some(message.clone());
116            }
117
118            return None;
119        }
120
121        None
122    }
123
124    pub fn replace_message(&mut self, updated: Message) -> bool {
125        for message in &mut self.messages {
126            if message.id() == updated.id() {
127                *message = updated;
128                return true;
129            }
130        }
131
132        self.messages.push(updated);
133        false
134    }
135
136    pub fn checkout(&mut self, message_id: &str) -> bool {
137        if self.messages.iter().any(|m| m.id() == message_id) {
138            self.active_message_id = Some(message_id.to_string());
139            true
140        } else {
141            false
142        }
143    }
144
145    pub fn get_active_thread(&self) -> Vec<&Message> {
146        if self.messages.is_empty() {
147            return Vec::new();
148        }
149
150        let head_id = if let Some(ref active_id) = self.active_message_id {
151            active_id.as_str()
152        } else {
153            self.messages.last().map_or("", |m| m.id())
154        };
155
156        let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
157        if current_msg.is_none() {
158            current_msg = self.messages.last();
159        }
160
161        let mut result = Vec::new();
162        let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
163
164        while let Some(msg) = current_msg {
165            result.push(msg);
166
167            current_msg = if let Some(parent_id) = msg.parent_message_id() {
168                id_map.get(parent_id).copied()
169            } else {
170                None
171            };
172        }
173
174        result.reverse();
175
176        debug!(
177            "Active thread: [{}]",
178            result
179                .iter()
180                .map(|msg| msg.id())
181                .collect::<Vec<_>>()
182                .join(", ")
183        );
184        result
185    }
186
187    pub fn get_thread_messages(&self) -> Vec<&Message> {
188        self.get_active_thread()
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user message containing a single text block.
    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    /// Builds an assistant message containing a single text block.
    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    #[test]
    fn test_editing_message_in_the_middle_of_conversation() {
        let mut graph = MessageGraph::new();

        let msg1 = create_user_message("msg1", None, "What is Rust?");
        graph.add_message(msg1.clone());

        let msg2 =
            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
        graph.add_message(msg2.clone());

        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
        graph.add_message(msg3.clone());

        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
        graph.add_message(msg4.clone());

        let edited_id = graph
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "What is Golang?".to_string(),
                }],
            )
            .unwrap();

        let messages_after_edit = graph.get_thread_messages();
        let message_ids_after_edit: Vec<&str> =
            messages_after_edit.iter().map(|m| m.id()).collect();

        assert_eq!(
            message_ids_after_edit.len(),
            1,
            "Active thread should only show the edited message"
        );
        assert_eq!(message_ids_after_edit[0], edited_id.as_str());

        assert!(graph.messages.iter().any(|m| m.id() == "msg1"));
        assert!(graph.messages.iter().any(|m| m.id() == "msg2"));
        assert!(graph.messages.iter().any(|m| m.id() == "msg3"));
        assert!(graph.messages.iter().any(|m| m.id() == "msg4"));

        let msg5 = create_assistant_message(
            "msg5",
            Some(&edited_id),
            "A systems programming language from Google.",
        );
        graph.add_message(msg5.clone());

        let final_messages = graph.get_thread_messages();
        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();

        assert_eq!(
            final_messages.len(),
            2,
            "Should have the edited message and the new response."
        );
        assert_eq!(final_message_ids[0], edited_id.as_str());
        assert_eq!(final_message_ids[1], "msg5");
    }

    #[test]
    fn test_get_thread_messages_after_edit() {
        let mut graph = MessageGraph::new();

        let msg1 = create_user_message("msg1", None, "hello");
        graph.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
        graph.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        graph.add_message(msg3_original.clone());

        let edited_id = graph
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
        graph.add_message(msg4.clone());

        let thread_messages = graph.get_thread_messages();

        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();

        assert_eq!(
            thread_message_ids.len(),
            4,
            "Should have 4 messages in the current thread"
        );
        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
        assert!(
            thread_message_ids.contains(&edited_id.as_str()),
            "Should contain the edited message"
        );
        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");

        assert!(
            graph.messages.iter().any(|m| m.id() == "msg3_original"),
            "Original message should still exist in message history"
        );
    }

    #[test]
    fn test_get_thread_messages_filters_other_branches() {
        let mut graph = MessageGraph::new();

        let msg1 = create_user_message("msg1", None, "hi");
        graph.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
        graph.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        graph.add_message(msg3_original.clone());

        let msg4_original =
            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
        graph.add_message(msg4_original.clone());

        let edited_id = graph
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4_new = create_assistant_message(
            "msg4_new",
            Some(&edited_id),
            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
        );
        graph.add_message(msg4_new.clone());

        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
        graph.add_message(msg5.clone());

        let thread_messages = graph.get_thread_messages();

        let user_messages: Vec<String> = thread_messages
            .iter()
            .filter(|m| matches!(m.data, MessageData::User { .. }))
            .map(|m| m.extract_text())
            .collect();

        assert_eq!(
            user_messages.len(),
            3,
            "Should have exactly 3 user messages, got {user_messages:?}"
        );
        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
        assert_eq!(
            user_messages[1], "how are you",
            "Second message should be 'how are you' (edited)"
        );
        assert_eq!(
            user_messages[2], "what messages have I sent you?",
            "Third message should be the question"
        );

        assert!(
            !user_messages.contains(&"thanks".to_string()),
            "Should NOT contain 'thanks' from the non-active branch"
        );

        assert!(
            graph.messages.iter().any(|m| m.id() == "msg3_original"),
            "Original 'thanks' message should still exist in message history"
        );
    }

    #[test]
    fn test_checkout_branch() {
        let mut graph = MessageGraph::new();

        let msg1 = create_user_message("msg1", None, "hello");
        graph.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
        graph.add_message(msg2.clone());

        let edited_id = graph
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "goodbye".to_string(),
                }],
            )
            .unwrap();

        assert_eq!(graph.active_message_id, Some(edited_id.clone()));
        let thread = graph.get_active_thread();
        assert_eq!(thread.len(), 1);
        assert_eq!(thread[0].id(), edited_id);

        assert!(graph.checkout("msg2"));
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));

        let thread = graph.get_active_thread();
        assert_eq!(thread.len(), 2);
        assert_eq!(thread[0].id(), "msg1");
        assert_eq!(thread[1].id(), "msg2");

        assert!(!graph.checkout("non-existent"));
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));
    }

    #[test]
    fn test_active_message_id_tracking() {
        let mut graph = MessageGraph::new();

        assert_eq!(graph.active_message_id, None);

        let msg1 = create_user_message("msg1", None, "hello");
        graph.add_message(msg1);
        assert_eq!(graph.active_message_id, Some("msg1".to_string()));

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
        graph.add_message(msg2);
        assert_eq!(graph.active_message_id, Some("msg2".to_string()));

        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
        graph.add_message(msg3);
        assert_eq!(graph.active_message_id, Some("msg3".to_string()));

        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
        graph.add_message(msg4);
        assert_eq!(graph.active_message_id, Some("msg4".to_string()));
    }

    #[test]
    fn test_empty_graph_has_empty_active_thread() {
        let graph = MessageGraph::new();
        assert!(graph.get_active_thread().is_empty());
        assert_eq!(graph.active_message_id, None);
    }

    #[test]
    fn test_active_thread_falls_back_to_last_message_for_unknown_active_id() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "hi"));
        graph.active_message_id = Some("missing".to_string());

        let thread = graph.get_active_thread();
        let ids: Vec<&str> = thread.iter().map(|m| m.id()).collect();
        assert_eq!(ids, vec!["msg1", "msg2"]);
    }

    #[test]
    fn test_add_message_from_data_links_to_active_message() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));

        let added = graph.add_message_from_data(MessageData::Assistant {
            content: vec![AssistantContent::Text {
                text: "hi".to_string(),
            }],
        });
        assert_eq!(added.parent_message_id(), Some("msg1"));
        let added_id = added.id().to_string();

        assert_eq!(graph.active_message_id.as_deref(), Some(added_id.as_str()));
        let thread = graph.get_active_thread();
        let ids: Vec<&str> = thread.iter().map(|m| m.id()).collect();
        assert_eq!(ids, vec!["msg1", added_id.as_str()]);
    }

    #[test]
    fn test_edit_message_rejects_non_user_and_unknown_messages() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "hi"));

        let new_content = || {
            vec![UserContent::Text {
                text: "edited".to_string(),
            }]
        };
        assert_eq!(graph.edit_message("msg2", new_content()), None);
        assert_eq!(graph.edit_message("missing", new_content()), None);
        assert_eq!(graph.messages.len(), 2);
        assert_eq!(graph.active_message_id.as_deref(), Some("msg2"));
    }

    #[test]
    fn test_update_command_execution_replaces_user_content() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "!ls"));
        graph.add_message(create_assistant_message("msg2", Some("msg1"), "ok"));

        let updated = graph
            .update_command_execution(
                "msg1",
                "ls".to_string(),
                "a.txt".to_string(),
                String::new(),
                0,
            )
            .expect("user message should be updated");
        assert_eq!(updated.id(), "msg1");

        let MessageData::User { content } = &graph.messages[0].data else {
            panic!("msg1 should still be a user message");
        };
        let [
            UserContent::CommandExecution {
                command,
                stdout,
                stderr,
                exit_code,
            },
        ] = content.as_slice()
        else {
            panic!("content should be a single command execution block");
        };
        assert_eq!(command.as_str(), "ls");
        assert_eq!(stdout.as_str(), "a.txt");
        assert!(stderr.is_empty());
        assert_eq!(*exit_code, 0);

        assert!(
            graph
                .update_command_execution(
                    "msg2",
                    "ls".to_string(),
                    String::new(),
                    String::new(),
                    1,
                )
                .is_none(),
            "assistant messages must not be rewritten"
        );
        assert!(
            graph
                .update_command_execution(
                    "missing",
                    "ls".to_string(),
                    String::new(),
                    String::new(),
                    1,
                )
                .is_none(),
            "unknown ids must be rejected"
        );
    }

    #[test]
    fn test_replace_message_updates_in_place_or_appends() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));

        assert!(graph.replace_message(create_user_message("msg1", None, "hello again")));
        assert_eq!(graph.messages.len(), 1);
        assert_eq!(graph.messages[0].extract_text(), "hello again");

        assert!(!graph.replace_message(create_user_message("msg2", Some("msg1"), "new")));
        assert_eq!(graph.messages.len(), 2);
        assert_eq!(graph.messages[1].id(), "msg2");
        // Appending through replace_message does not move the active head.
        assert_eq!(graph.active_message_id.as_deref(), Some("msg1"));
    }

    #[test]
    fn test_clear_removes_messages_and_active_head() {
        let mut graph = MessageGraph::new();
        graph.add_message(create_user_message("msg1", None, "hello"));

        graph.clear();

        assert!(graph.messages.is_empty());
        assert_eq!(graph.active_message_id, None);
        assert!(graph.get_active_thread().is_empty());
    }
}