steer_core/app/
conversation.rs

1use crate::api::Client as ApiClient;
2use crate::config::model::ModelId;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::path::PathBuf;
7use std::str::FromStr;
8use std::time::{SystemTime, UNIX_EPOCH};
9use steer_tools::ToolCall;
10pub use steer_tools::result::ToolResult;
11use tracing::debug;
12
13use strum_macros::Display;
14use tokio_util::sync::CancellationToken;
15
16/// Result of a conversation compaction operation
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "result_type", rename_all = "snake_case")]
19pub enum CompactResult {
20    /// Compaction completed successfully with the summary
21    Success(String),
22    /// Compaction was cancelled by the user
23    Cancelled,
24    /// Not enough messages to compact
25    InsufficientMessages,
26}
27
28/// Response from executing an app command
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(tag = "response_type", rename_all = "snake_case")]
31pub enum CommandResponse {
32    /// Simple text response
33    Text(String),
34    /// Compact command response with structured result
35    Compact(CompactResult),
36}
37
38/// Types of app commands that can be executed
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40#[serde(tag = "command_type", rename_all = "snake_case")]
41pub enum AppCommandType {
42    /// Model management - list or change models
43    Model { target: Option<String> },
44    /// Clear the conversation
45    Clear,
46    /// Compact the conversation
47    Compact,
48}
49
50impl AppCommandType {
51    /// Parse a command string into an AppCommandType
52    pub fn parse(input: &str) -> Result<Self, SlashCommandError> {
53        // Trim whitespace and remove leading slash if present
54        let command = input.trim();
55        let command = command.strip_prefix('/').unwrap_or(command);
56
57        // Split to get command name and args
58        let parts: Vec<&str> = command.split_whitespace().collect();
59        if parts.is_empty() {
60            return Err(SlashCommandError::InvalidFormat(
61                "Empty command".to_string(),
62            ));
63        }
64
65        match parts[0] {
66            "model" => {
67                let target = if parts.len() > 1 {
68                    Some(parts[1..].join(" "))
69                } else {
70                    None
71                };
72                Ok(AppCommandType::Model { target })
73            }
74            "clear" => Ok(AppCommandType::Clear),
75            "compact" => Ok(AppCommandType::Compact),
76            cmd => Err(SlashCommandError::UnknownCommand(cmd.to_string())),
77        }
78    }
79
80    /// Get the command string representation
81    pub fn as_command_str(&self) -> String {
82        match self {
83            AppCommandType::Model { target } => {
84                if let Some(model) = target {
85                    format!("model {model}")
86                } else {
87                    "model".to_string()
88                }
89            }
90            AppCommandType::Clear => "clear".to_string(),
91            AppCommandType::Compact => "compact".to_string(),
92        }
93    }
94}
95
96impl fmt::Display for AppCommandType {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        write!(f, "/{}", self.as_command_str())
99    }
100}
101
102impl FromStr for AppCommandType {
103    type Err = SlashCommandError;
104
105    fn from_str(s: &str) -> Result<Self, Self::Err> {
106        Self::parse(s)
107    }
108}
109
110/// Errors that can occur when parsing slash commands
111#[derive(Debug, thiserror::Error)]
112pub enum SlashCommandError {
113    #[error("Unknown command: {0}")]
114    UnknownCommand(String),
115    #[error("Invalid command format: {0}")]
116    InvalidFormat(String),
117}
118
119/// Role in the conversation
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Copy, Display)]
121pub enum Role {
122    User,
123    Assistant,
124    Tool,
125}
126
127/// Content that can be sent by a user
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
129#[serde(tag = "type", rename_all = "snake_case")]
130pub enum UserContent {
131    Text {
132        text: String,
133    },
134    CommandExecution {
135        command: String,
136        stdout: String,
137        stderr: String,
138        exit_code: i32,
139    },
140    AppCommand {
141        command: AppCommandType,
142        response: Option<CommandResponse>,
143    },
144    // TODO: support attachments
145}
146
147impl UserContent {
148    pub fn format_command_execution_as_xml(
149        command: &str,
150        stdout: &str,
151        stderr: &str,
152        exit_code: i32,
153    ) -> String {
154        format!(
155            r#"<executed_command>
156    <command>{command}</command>
157    <stdout>{stdout}</stdout>
158    <stderr>{stderr}</stderr>
159    <exit_code>{exit_code}</exit_code>
160</executed_command>"#
161        )
162    }
163}
164
165/// Different types of thought content from AI models
166#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
167#[serde(tag = "thought_type")]
168pub enum ThoughtContent {
169    /// Simple thought content (e.g., from Gemini)
170    #[serde(rename = "simple")]
171    Simple { text: String },
172    /// Claude-style thinking with signature
173    #[serde(rename = "signed")]
174    Signed { text: String, signature: String },
175    /// Claude-style redacted thinking
176    #[serde(rename = "redacted")]
177    Redacted { data: String },
178}
179
180impl ThoughtContent {
181    /// Extract displayable text from any thought type
182    pub fn display_text(&self) -> String {
183        match self {
184            ThoughtContent::Simple { text } => text.clone(),
185            ThoughtContent::Signed { text, .. } => text.clone(),
186            ThoughtContent::Redacted { .. } => "[Redacted Thinking]".to_string(),
187        }
188    }
189}
190
191/// Content that can be sent by an assistant
192#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193#[serde(tag = "type", rename_all = "snake_case")]
194pub enum AssistantContent {
195    Text { text: String },
196    ToolCall { tool_call: ToolCall },
197    Thought { thought: ThoughtContent },
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct Message {
202    pub timestamp: u64,
203    pub id: String,
204    pub parent_message_id: Option<String>,
205    pub data: MessageData,
206}
207
208/// A message in the conversation, with role-specific content
209#[derive(Debug, Clone, Serialize, Deserialize)]
210#[serde(tag = "role", rename_all = "lowercase")]
211pub enum MessageData {
212    User {
213        content: Vec<UserContent>,
214    },
215    Assistant {
216        content: Vec<AssistantContent>,
217    },
218    Tool {
219        tool_use_id: String,
220        result: ToolResult,
221    },
222}
223
224impl Message {
225    pub fn role(&self) -> Role {
226        match &self.data {
227            MessageData::User { .. } => Role::User,
228            MessageData::Assistant { .. } => Role::Assistant,
229            MessageData::Tool { .. } => Role::Tool,
230        }
231    }
232
233    pub fn id(&self) -> &str {
234        &self.id
235    }
236
237    pub fn timestamp(&self) -> u64 {
238        self.timestamp
239    }
240
241    pub fn parent_message_id(&self) -> Option<&str> {
242        self.parent_message_id.as_deref()
243    }
244
245    /// Helper to get current timestamp
246    pub fn current_timestamp() -> u64 {
247        SystemTime::now()
248            .duration_since(UNIX_EPOCH)
249            .expect("Time went backwards")
250            .as_secs()
251    }
252
253    /// Helper to generate unique IDs
254    pub fn generate_id(prefix: &str, _timestamp: u64) -> String {
255        use uuid::Uuid;
256        format!("{}_{}", prefix, Uuid::now_v7())
257    }
258
259    /// Extract text content from the message
260    pub fn extract_text(&self) -> String {
261        match &self.data {
262            MessageData::User { content } => content
263                .iter()
264                .filter_map(|c| match c {
265                    UserContent::Text { text } => Some(text.clone()),
266                    UserContent::CommandExecution { stdout, .. } => Some(stdout.clone()),
267                    UserContent::AppCommand { response, .. } => {
268                        response.as_ref().map(|r| match r {
269                            CommandResponse::Text(t) => t.clone(),
270                            CommandResponse::Compact(CompactResult::Success(s)) => s.clone(),
271                            _ => String::new(),
272                        })
273                    }
274                })
275                .collect::<Vec<_>>()
276                .join("\n"),
277            MessageData::Assistant { content } => content
278                .iter()
279                .filter_map(|c| match c {
280                    AssistantContent::Text { text } => Some(text.clone()),
281                    _ => None,
282                })
283                .collect::<Vec<_>>()
284                .join("\n"),
285            MessageData::Tool { result, .. } => result.llm_format(),
286        }
287    }
288
289    /// Get a string representation of the message content
290    pub fn content_string(&self) -> String {
291        match &self.data {
292            MessageData::User { content } => content
293                .iter()
294                .map(|c| match c {
295                    UserContent::Text { text } => text.clone(),
296                    UserContent::CommandExecution {
297                        command,
298                        stdout,
299                        stderr,
300                        exit_code,
301                    } => {
302                        let mut output = format!("$ {command}\n{stdout}");
303                        if *exit_code != 0 {
304                            output.push_str(&format!("\nExit code: {exit_code}"));
305                            if !stderr.is_empty() {
306                                output.push_str(&format!("\nError: {stderr}"));
307                            }
308                        }
309                        output
310                    }
311                    UserContent::AppCommand { command, response } => {
312                        if let Some(resp) = response {
313                            let text = match resp {
314                                CommandResponse::Text(msg) => msg.clone(),
315                                CommandResponse::Compact(result) => match result {
316                                    CompactResult::Success(summary) => summary.clone(),
317                                    CompactResult::Cancelled => {
318                                        "Compact command cancelled.".to_string()
319                                    }
320                                    CompactResult::InsufficientMessages => {
321                                        "Not enough messages to compact (minimum 10 required)."
322                                            .to_string()
323                                    }
324                                },
325                            };
326                            format!("/{}\n{}", command.as_command_str(), text)
327                        } else {
328                            format!("/{}", command.as_command_str())
329                        }
330                    }
331                })
332                .collect::<Vec<_>>()
333                .join("\n"),
334            MessageData::Assistant { content } => content
335                .iter()
336                .map(|c| match c {
337                    AssistantContent::Text { text } => text.clone(),
338                    AssistantContent::ToolCall { tool_call } => {
339                        format!("[Tool Call: {}]", tool_call.name)
340                    }
341                    AssistantContent::Thought { thought } => {
342                        format!("[Thought: {}]", thought.display_text())
343                    }
344                })
345                .collect::<Vec<_>>()
346                .join("\n"),
347            MessageData::Tool { result, .. } => {
348                // This is a simplified representation. The TUI will have a more detailed view.
349                let result_type = match result {
350                    ToolResult::Search(_) => "Search Result",
351                    ToolResult::FileList(_) => "File List",
352                    ToolResult::FileContent(_) => "File Content",
353                    ToolResult::Edit(_) => "Edit Result",
354                    ToolResult::Bash(_) => "Bash Result",
355                    ToolResult::Glob(_) => "Glob Result",
356                    ToolResult::TodoRead(_) => "Todo List",
357                    ToolResult::TodoWrite(_) => "Todo Update",
358                    ToolResult::Fetch(_) => "Fetch Result",
359                    ToolResult::Agent(_) => "Agent Result",
360                    ToolResult::External(_) => "External Tool Result",
361                    ToolResult::Error(_) => "Error",
362                };
363                format!("[Tool Result: {result_type}]")
364            }
365        }
366    }
367}
368
/// Instruction text sent as the final user message when compacting a
/// conversation: it asks the model for an `<analysis>` section followed by a
/// structured `<summary>` of the conversation so far.
const SUMMARY_PROMPT: &str = r#"Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context.

Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:

1. Chronologically analyze each message and section of the conversation. For each section thoroughly identify:
   - The user's explicit requests and intents
   - Your approach to addressing the user's requests
   - Key decisions, technical concepts and code patterns
   - Specific details like file names, full code snippets, function signatures, file edits, etc
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.

Your summary should include the following sections:

1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
4. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
5. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
6. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
7. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests without confirming with the user first.
                       If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.

Here's an example of how your output should be structured:

<example>
<analysis>
[Your thought process, ensuring all points are covered thoroughly and accurately]
</analysis>

<summary>
1. Primary Request and Intent:
   [Detailed description]

2. Key Technical Concepts:
   - [Concept 1]
   - [Concept 2]
   - [...]

3. Files and Code Sections:
   - [File Name 1]
      - [Summary of why this file is important]
      - [Summary of the changes made to this file, if any]
      - [Important Code Snippet]
   - [File Name 2]
      - [Important Code Snippet]
   - [...]

4. Problem Solving:
   [Description of solved problems and ongoing troubleshooting]

5. Pending Tasks:
   - [Task 1]
   - [Task 2]
   - [...]

6. Current Work:
   [Precise description of current work]

7. Optional Next Step:
   [Optional Next step to take]

</summary>
</example>

Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.

There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. Examples of instructions include:
<example>
## Compact Instructions
When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them.
</example>

<example>
# Summary instructions
When you are using compact - please focus on test output and code changes. Include file reads verbatim.
</example>"#;
446
447/// A conversation history
448#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct Conversation {
450    pub messages: Vec<Message>,
451    pub working_directory: PathBuf,
452    /// The ID of the currently active message (head of the selected branch).
453    /// None means use last message semantics for backward compatibility.
454    pub active_message_id: Option<String>,
455}
456
457impl Default for Conversation {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
463impl Conversation {
464    pub fn new() -> Self {
465        Self {
466            messages: Vec::new(),
467            working_directory: PathBuf::new(),
468            active_message_id: None,
469        }
470    }
471
472    pub fn add_message(&mut self, message: Message) {
473        self.active_message_id = Some(message.id().to_string());
474        self.messages.push(message);
475    }
476
477    pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
478        debug!(target: "conversation::add_message", "Adding message: {:?}", message_data);
479        self.messages.push(Message {
480            data: message_data,
481            id: Message::generate_id("", Message::current_timestamp()),
482            timestamp: Message::current_timestamp(),
483            parent_message_id: self.active_message_id.clone(),
484        });
485        self.active_message_id = Some(self.messages.last().unwrap().id().to_string());
486        self.messages.last().unwrap()
487    }
488
489    pub fn clear(&mut self) {
490        debug!(target:"conversation::clear", "Clearing conversation");
491        self.messages.clear();
492        self.active_message_id = None;
493    }
494
495    /// Find the tool name by its ID by searching through assistant messages with tool calls
496    pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
497        for message in self.messages.iter() {
498            if let MessageData::Assistant { content, .. } = &message.data {
499                for content_block in content {
500                    if let AssistantContent::ToolCall { tool_call } = content_block {
501                        if tool_call.id == tool_id {
502                            return Some(tool_call.name.clone());
503                        }
504                    }
505                }
506            }
507        }
508        None
509    }
510
511    /// Compact the conversation by summarizing older messages in the active thread
512    pub async fn compact(
513        &mut self,
514        api_client: &ApiClient,
515        model: ModelId,
516        token: CancellationToken,
517    ) -> crate::error::Result<CompactResult> {
518        // Get only the active thread
519        let thread = self.get_active_thread();
520
521        // Skip if we don't have enough messages to compact
522        if thread.len() < 10 {
523            return Ok(CompactResult::InsufficientMessages);
524        }
525
526        // Build prompt from active thread only
527        let mut prompt_messages: Vec<Message> = thread.into_iter().cloned().collect();
528        let last_msg_id = prompt_messages.last().map(|m| m.id().to_string());
529
530        prompt_messages.push(Message {
531            data: MessageData::User {
532                content: vec![UserContent::Text {
533                    text: SUMMARY_PROMPT.to_string(),
534                }],
535            },
536            timestamp: Message::current_timestamp(),
537            id: Message::generate_id("user", Message::current_timestamp()),
538            parent_message_id: last_msg_id.clone(),
539        });
540
541        let summary = tokio::select! {
542            biased;
543            result = api_client.complete(
544                &model,
545                prompt_messages,
546                None,
547                None,
548                None,
549                token.clone(),
550            ) => result.map_err(crate::error::Error::Api)?,
551            _ = token.cancelled() => {
552                return Ok(CompactResult::Cancelled);
553            }
554        };
555
556        let summary_text = summary.extract_text();
557
558        // Create a summary marker message (DO NOT clear messages)
559        let timestamp = Message::current_timestamp();
560        let summary_id = Message::generate_id("user", timestamp);
561
562        // Add the summary as a user message continuing the active thread
563        let summary_message = Message {
564            data: MessageData::User {
565                content: vec![UserContent::Text {
566                    text: format!("[COMPACTED SUMMARY]\n\n{summary_text}"),
567                }],
568            },
569            timestamp,
570            id: summary_id.clone(),
571            parent_message_id: None,
572        };
573
574        self.messages.push(summary_message);
575
576        // Update active_message_id to the summary marker
577        self.active_message_id = Some(summary_id);
578
579        Ok(CompactResult::Success(summary_text))
580    }
581
582    /// Edit a message non-destructively by creating a new branch.
583    /// Returns the ID of the new message if successful.
584    pub fn edit_message(
585        &mut self,
586        message_id: &str,
587        new_content: Vec<UserContent>,
588    ) -> Option<String> {
589        // Find the message to edit
590        let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
591
592        // Only allow editing user messages for now
593        if !matches!(&message_to_edit.data, MessageData::User { .. }) {
594            return None;
595        }
596
597        // Get the parent_message_id from the original message
598        let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
599
600        // Create the new message as a branch from the same parent
601        let new_message_id = Message::generate_id("user", Message::current_timestamp());
602        let edited_message = Message {
603            data: MessageData::User {
604                content: new_content,
605            },
606            timestamp: Message::current_timestamp(),
607            id: new_message_id.clone(),
608            parent_message_id: parent_id,
609        };
610
611        // Add the edited message (original remains in history)
612        self.messages.push(edited_message);
613
614        // Update active_message_id to the new branch head
615        self.active_message_id = Some(new_message_id.clone());
616
617        Some(new_message_id)
618    }
619
620    /// Switch to another branch by setting active_message_id
621    pub fn checkout(&mut self, message_id: &str) -> bool {
622        // Verify the message exists
623        if self.messages.iter().any(|m| m.id() == message_id) {
624            self.active_message_id = Some(message_id.to_string());
625            true
626        } else {
627            false
628        }
629    }
630
631    /// Get messages in the currently active thread
632    pub fn get_active_thread(&self) -> Vec<&Message> {
633        if self.messages.is_empty() {
634            return Vec::new();
635        }
636
637        // Determine the head of the active thread
638        let head_id = if let Some(ref active_id) = self.active_message_id {
639            // Use the explicitly set active message
640            active_id.as_str()
641        } else {
642            // Backward compatibility: use last message
643            self.messages.last().map(|m| m.id()).unwrap_or("")
644        };
645
646        // Find the head message
647        let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
648        if current_msg.is_none() {
649            // If active_message_id is invalid, fall back to last message
650            current_msg = self.messages.last();
651        }
652
653        let mut result = Vec::new();
654        let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
655
656        // Walk backwards from head to root
657        while let Some(msg) = current_msg {
658            result.push(msg);
659
660            // Find parent message using the id_map
661            current_msg = if let Some(parent_id) = msg.parent_message_id() {
662                id_map.get(parent_id).copied()
663            } else {
664                None
665            };
666        }
667
668        result.reverse();
669
670        debug!(
671            "Active thread: [{}]",
672            result
673                .iter()
674                .map(|msg| msg.id())
675                .collect::<Vec<_>>()
676                .join(", ")
677        );
678        result
679    }
680
681    /// Get messages in the current active branch by following parent links from the last message
682    /// This is a thin wrapper around get_active_thread for backward compatibility
683    pub fn get_thread_messages(&self) -> Vec<&Message> {
684        self.get_active_thread()
685    }
686}
687
688#[cfg(test)]
689mod tests {
690    use crate::app::conversation::{
691        AssistantContent, Conversation, Message, MessageData, UserContent,
692    };
693
694    /// Helper function to create a user message for testing
695    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
696        Message {
697            data: MessageData::User {
698                content: vec![UserContent::Text {
699                    text: content.to_string(),
700                }],
701            },
702            timestamp: Message::current_timestamp(),
703            id: id.to_string(),
704            parent_message_id: parent_id.map(String::from),
705        }
706    }
707
708    /// Helper function to create an assistant message for testing
709    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
710        Message {
711            data: MessageData::Assistant {
712                content: vec![AssistantContent::Text {
713                    text: content.to_string(),
714                }],
715            },
716            timestamp: Message::current_timestamp(),
717            id: id.to_string(),
718            parent_message_id: parent_id.map(String::from),
719        }
720    }
721
722    #[test]
723    fn test_editing_message_in_the_middle_of_conversation() {
724        let mut conversation = Conversation::new();
725
726        // 1. Build an initial conversation
727        let msg1 = create_user_message("msg1", None, "What is Rust?");
728        conversation.add_message(msg1.clone());
729
730        let msg2 =
731            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
732        conversation.add_message(msg2.clone());
733
734        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
735        conversation.add_message(msg3.clone());
736
737        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
738        conversation.add_message(msg4.clone());
739
740        // 2. Edit the *first* user message
741        let edited_id = conversation
742            .edit_message(
743                "msg1",
744                vec![UserContent::Text {
745                    text: "What is Golang?".to_string(),
746                }],
747            )
748            .unwrap();
749
750        // 3. Check the state after editing
751        let messages_after_edit = conversation.get_thread_messages();
752        let message_ids_after_edit: Vec<&str> =
753            messages_after_edit.iter().map(|m| m.id()).collect();
754
755        assert_eq!(
756            message_ids_after_edit.len(),
757            1,
758            "Active thread should only show the edited message"
759        );
760        assert_eq!(message_ids_after_edit[0], edited_id.as_str());
761
762        // Verify original branch still exists in messages
763        assert!(conversation.messages.iter().any(|m| m.id() == "msg1"));
764        assert!(conversation.messages.iter().any(|m| m.id() == "msg2"));
765        assert!(conversation.messages.iter().any(|m| m.id() == "msg3"));
766        assert!(conversation.messages.iter().any(|m| m.id() == "msg4"));
767
768        // 4. Add a new message to the new branch of conversation
769        let msg5 = create_assistant_message(
770            "msg5",
771            Some(&edited_id),
772            "A systems programming language from Google.",
773        );
774        conversation.add_message(msg5.clone());
775
776        // 5. Check the final state of the conversation
777        let final_messages = conversation.get_thread_messages();
778        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();
779
780        assert_eq!(
781            final_messages.len(),
782            2,
783            "Should have the edited message and the new response."
784        );
785        assert_eq!(final_message_ids[0], edited_id.as_str());
786        assert_eq!(final_message_ids[1], "msg5");
787    }
788
789    #[test]
790    fn test_get_thread_messages_after_edit() {
791        let mut conversation = Conversation::new();
792
793        // 1. Initial conversation
794        let msg1 = create_user_message("msg1", None, "hello");
795        conversation.add_message(msg1.clone());
796
797        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
798        conversation.add_message(msg2.clone());
799
800        // This is the message that will be "edited out"
801        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
802        conversation.add_message(msg3_original.clone());
803
804        // 2. Edit the last user message ("thanks")
805        let edited_id = conversation
806            .edit_message(
807                "msg3_original",
808                vec![UserContent::Text {
809                    text: "how are you".to_string(),
810                }],
811            )
812            .unwrap();
813
814        // 3. Add a new assistant message to the new branch
815        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
816        conversation.add_message(msg4.clone());
817
818        // 4. Get messages for the current thread
819        let thread_messages = conversation.get_thread_messages();
820
821        // 5. Assertions
822        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();
823
824        // Should contain the root, the first assistant message, the *new* user message, and the final assistant message
825        assert_eq!(
826            thread_message_ids.len(),
827            4,
828            "Should have 4 messages in the current thread"
829        );
830        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
831        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
832        assert!(
833            thread_message_ids.contains(&edited_id.as_str()),
834            "Should contain the edited message"
835        );
836        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");
837
838        // Verify the original message still exists in the full message list
839        assert!(
840            conversation
841                .messages
842                .iter()
843                .any(|m| m.id() == "msg3_original"),
844            "Original message should still exist in conversation history"
845        );
846    }
847
848    #[test]
849    fn test_get_thread_messages_filters_other_branches() {
850        let mut conversation = Conversation::new();
851
852        // 1. Initial conversation: "hi"
853        let msg1 = create_user_message("msg1", None, "hi");
854        conversation.add_message(msg1.clone());
855
856        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
857        conversation.add_message(msg2.clone());
858
859        // 2. User says "thanks" (this will be edited out)
860        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
861        conversation.add_message(msg3_original.clone());
862
863        let msg4_original =
864            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
865        conversation.add_message(msg4_original.clone());
866
867        // 3. Edit the "thanks" message to "how are you"
868        let edited_id = conversation
869            .edit_message(
870                "msg3_original",
871                vec![UserContent::Text {
872                    text: "how are you".to_string(),
873                }],
874            )
875            .unwrap();
876
877        // 4. Add assistant response in the new branch
878        let msg4_new = create_assistant_message(
879            "msg4_new",
880            Some(&edited_id),
881            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
882        );
883        conversation.add_message(msg4_new.clone());
884
885        // 5. User asks "what messages have I sent you?"
886        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
887        conversation.add_message(msg5.clone());
888
889        // 6. Get messages for the current thread - this should NOT include "thanks"
890        let thread_messages = conversation.get_thread_messages();
891
892        // Extract the user messages
893        let user_messages: Vec<String> = thread_messages
894            .iter()
895            .filter(|m| matches!(m.data, MessageData::User { .. }))
896            .map(|m| m.extract_text())
897            .collect();
898
899        println!("User messages seen: {user_messages:?}");
900
901        // Assertions
902        assert_eq!(
903            user_messages.len(),
904            3,
905            "Should have exactly 3 user messages"
906        );
907        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
908        assert_eq!(
909            user_messages[1], "how are you",
910            "Second message should be 'how are you' (edited)"
911        );
912        assert_eq!(
913            user_messages[2], "what messages have I sent you?",
914            "Third message should be the question"
915        );
916
917        // CRITICAL: Should NOT contain "thanks" from the other branch
918        assert!(
919            !user_messages.contains(&"thanks".to_string()),
920            "Should NOT contain 'thanks' from the non-active branch"
921        );
922
923        // But the original message should still exist in the full conversation
924        assert!(
925            conversation
926                .messages
927                .iter()
928                .any(|m| m.id() == "msg3_original"),
929            "Original 'thanks' message should still exist in conversation history"
930        );
931    }
932
933    #[test]
934    fn test_checkout_branch() {
935        let mut conversation = Conversation::new();
936
937        // Create initial conversation
938        let msg1 = create_user_message("msg1", None, "hello");
939        conversation.add_message(msg1.clone());
940
941        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
942        conversation.add_message(msg2.clone());
943
944        // Edit to create a branch
945        let edited_id = conversation
946            .edit_message(
947                "msg1",
948                vec![UserContent::Text {
949                    text: "goodbye".to_string(),
950                }],
951            )
952            .unwrap();
953
954        // Verify we're on the new branch
955        assert_eq!(conversation.active_message_id, Some(edited_id.clone()));
956        let thread = conversation.get_active_thread();
957        assert_eq!(thread.len(), 1);
958        assert_eq!(thread[0].id(), edited_id);
959
960        // Checkout the original branch
961        assert!(conversation.checkout("msg2"));
962        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
963
964        // Verify we're back on the original branch
965        let thread = conversation.get_active_thread();
966        assert_eq!(thread.len(), 2);
967        assert_eq!(thread[0].id(), "msg1");
968        assert_eq!(thread[1].id(), "msg2");
969
970        // Try to checkout non-existent message
971        assert!(!conversation.checkout("non-existent"));
972        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
973    }
974
975    #[test]
976    fn test_active_message_id_tracking() {
977        let mut conversation = Conversation::new();
978
979        // Initially no active message
980        assert_eq!(conversation.active_message_id, None);
981
982        // Add root message - should become active
983        let msg1 = create_user_message("msg1", None, "hello");
984        conversation.add_message(msg1);
985        assert_eq!(conversation.active_message_id, Some("msg1".to_string()));
986
987        // Add response - should update active
988        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
989        conversation.add_message(msg2);
990        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
991
992        // Add another branch from msg1
993        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
994        conversation.add_message(msg3);
995        assert_eq!(conversation.active_message_id, Some("msg3".to_string()));
996
997        // Continue from active
998        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
999        conversation.add_message(msg4);
1000        assert_eq!(conversation.active_message_id, Some("msg4".to_string()));
1001    }
1002}