steer_core/app/
conversation.rs

1use crate::api::Client as ApiClient;
2use crate::api::Model;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::path::PathBuf;
7use std::str::FromStr;
8use std::time::{SystemTime, UNIX_EPOCH};
9use steer_tools::ToolCall;
10pub use steer_tools::result::ToolResult;
11use tracing::debug;
12
13use strum_macros::Display;
14use tokio_util::sync::CancellationToken;
15
16/// Result of a conversation compaction operation
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "result_type", rename_all = "snake_case")]
19pub enum CompactResult {
20    /// Compaction completed successfully with the summary
21    Success(String),
22    /// Compaction was cancelled by the user
23    Cancelled,
24    /// Not enough messages to compact
25    InsufficientMessages,
26}
27
28/// Response from executing an app command
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(tag = "response_type", rename_all = "snake_case")]
31pub enum CommandResponse {
32    /// Simple text response
33    Text(String),
34    /// Compact command response with structured result
35    Compact(CompactResult),
36}
37
38/// Types of app commands that can be executed
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40#[serde(tag = "command_type", rename_all = "snake_case")]
41pub enum AppCommandType {
42    /// Model management - list or change models
43    Model { target: Option<String> },
44    /// Clear the conversation
45    Clear,
46    /// Compact the conversation
47    Compact,
48}
49
50impl AppCommandType {
51    /// Parse a command string into an AppCommandType
52    pub fn parse(input: &str) -> Result<Self, SlashCommandError> {
53        // Trim whitespace and remove leading slash if present
54        let command = input.trim();
55        let command = command.strip_prefix('/').unwrap_or(command);
56
57        // Split to get command name and args
58        let parts: Vec<&str> = command.split_whitespace().collect();
59        if parts.is_empty() {
60            return Err(SlashCommandError::InvalidFormat(
61                "Empty command".to_string(),
62            ));
63        }
64
65        match parts[0] {
66            "model" => {
67                let target = if parts.len() > 1 {
68                    Some(parts[1..].join(" "))
69                } else {
70                    None
71                };
72                Ok(AppCommandType::Model { target })
73            }
74            "clear" => Ok(AppCommandType::Clear),
75            "compact" => Ok(AppCommandType::Compact),
76            cmd => Err(SlashCommandError::UnknownCommand(cmd.to_string())),
77        }
78    }
79
80    /// Get the command string representation
81    pub fn as_command_str(&self) -> String {
82        match self {
83            AppCommandType::Model { target } => {
84                if let Some(model) = target {
85                    format!("model {model}")
86                } else {
87                    "model".to_string()
88                }
89            }
90            AppCommandType::Clear => "clear".to_string(),
91            AppCommandType::Compact => "compact".to_string(),
92        }
93    }
94}
95
96impl fmt::Display for AppCommandType {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        write!(f, "/{}", self.as_command_str())
99    }
100}
101
102impl FromStr for AppCommandType {
103    type Err = SlashCommandError;
104
105    fn from_str(s: &str) -> Result<Self, Self::Err> {
106        Self::parse(s)
107    }
108}
109
110/// Errors that can occur when parsing slash commands
111#[derive(Debug, thiserror::Error)]
112pub enum SlashCommandError {
113    #[error("Unknown command: {0}")]
114    UnknownCommand(String),
115    #[error("Invalid command format: {0}")]
116    InvalidFormat(String),
117}
118
119/// Role in the conversation
120#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Copy, Display)]
121pub enum Role {
122    User,
123    Assistant,
124    Tool,
125}
126
127/// Content that can be sent by a user
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
129#[serde(tag = "type", rename_all = "snake_case")]
130pub enum UserContent {
131    Text {
132        text: String,
133    },
134    CommandExecution {
135        command: String,
136        stdout: String,
137        stderr: String,
138        exit_code: i32,
139    },
140    AppCommand {
141        command: AppCommandType,
142        response: Option<CommandResponse>,
143    },
144    // TODO: support attachments
145}
146
147impl UserContent {
148    pub fn format_command_execution_as_xml(
149        command: &str,
150        stdout: &str,
151        stderr: &str,
152        exit_code: i32,
153    ) -> String {
154        format!(
155            r#"<executed_command>
156    <command>{command}</command>
157    <stdout>{stdout}</stdout>
158    <stderr>{stderr}</stderr>
159    <exit_code>{exit_code}</exit_code>
160</executed_command>"#
161        )
162    }
163}
164
165/// Different types of thought content from AI models
166#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
167#[serde(tag = "thought_type")]
168pub enum ThoughtContent {
169    /// Simple thought content (e.g., from Gemini)
170    #[serde(rename = "simple")]
171    Simple { text: String },
172    /// Claude-style thinking with signature
173    #[serde(rename = "signed")]
174    Signed { text: String, signature: String },
175    /// Claude-style redacted thinking
176    #[serde(rename = "redacted")]
177    Redacted { data: String },
178}
179
180impl ThoughtContent {
181    /// Extract displayable text from any thought type
182    pub fn display_text(&self) -> String {
183        match self {
184            ThoughtContent::Simple { text } => text.clone(),
185            ThoughtContent::Signed { text, .. } => text.clone(),
186            ThoughtContent::Redacted { .. } => "[Redacted Thinking]".to_string(),
187        }
188    }
189}
190
191/// Content that can be sent by an assistant
192#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193#[serde(tag = "type", rename_all = "snake_case")]
194pub enum AssistantContent {
195    Text { text: String },
196    ToolCall { tool_call: ToolCall },
197    Thought { thought: ThoughtContent },
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct Message {
202    pub timestamp: u64,
203    pub id: String,
204    pub parent_message_id: Option<String>,
205    pub data: MessageData,
206}
207
208/// A message in the conversation, with role-specific content
209#[derive(Debug, Clone, Serialize, Deserialize)]
210#[serde(tag = "role", rename_all = "lowercase")]
211pub enum MessageData {
212    User {
213        content: Vec<UserContent>,
214    },
215    Assistant {
216        content: Vec<AssistantContent>,
217    },
218    Tool {
219        tool_use_id: String,
220        result: ToolResult,
221    },
222}
223
224impl Message {
225    pub fn role(&self) -> Role {
226        match &self.data {
227            MessageData::User { .. } => Role::User,
228            MessageData::Assistant { .. } => Role::Assistant,
229            MessageData::Tool { .. } => Role::Tool,
230        }
231    }
232
233    pub fn id(&self) -> &str {
234        &self.id
235    }
236
237    pub fn timestamp(&self) -> u64 {
238        self.timestamp
239    }
240
241    pub fn parent_message_id(&self) -> Option<&str> {
242        self.parent_message_id.as_deref()
243    }
244
245    /// Helper to get current timestamp
246    pub fn current_timestamp() -> u64 {
247        SystemTime::now()
248            .duration_since(UNIX_EPOCH)
249            .expect("Time went backwards")
250            .as_secs()
251    }
252
253    /// Helper to generate unique IDs
254    pub fn generate_id(prefix: &str, _timestamp: u64) -> String {
255        use uuid::Uuid;
256        format!("{}_{}", prefix, Uuid::now_v7())
257    }
258
259    /// Extract text content from the message
260    pub fn extract_text(&self) -> String {
261        match &self.data {
262            MessageData::User { content } => content
263                .iter()
264                .filter_map(|c| match c {
265                    UserContent::Text { text } => Some(text.clone()),
266                    UserContent::CommandExecution { stdout, .. } => Some(stdout.clone()),
267                    UserContent::AppCommand { response, .. } => {
268                        response.as_ref().map(|r| match r {
269                            CommandResponse::Text(t) => t.clone(),
270                            CommandResponse::Compact(CompactResult::Success(s)) => s.clone(),
271                            _ => String::new(),
272                        })
273                    }
274                })
275                .collect::<Vec<_>>()
276                .join("\n"),
277            MessageData::Assistant { content } => content
278                .iter()
279                .filter_map(|c| match c {
280                    AssistantContent::Text { text } => Some(text.clone()),
281                    _ => None,
282                })
283                .collect::<Vec<_>>()
284                .join("\n"),
285            MessageData::Tool { result, .. } => result.llm_format(),
286        }
287    }
288
289    /// Get a string representation of the message content
290    pub fn content_string(&self) -> String {
291        match &self.data {
292            MessageData::User { content } => content
293                .iter()
294                .map(|c| match c {
295                    UserContent::Text { text } => text.clone(),
296                    UserContent::CommandExecution {
297                        command,
298                        stdout,
299                        stderr,
300                        exit_code,
301                    } => {
302                        let mut output = format!("$ {command}\n{stdout}");
303                        if *exit_code != 0 {
304                            output.push_str(&format!("\nExit code: {exit_code}"));
305                            if !stderr.is_empty() {
306                                output.push_str(&format!("\nError: {stderr}"));
307                            }
308                        }
309                        output
310                    }
311                    UserContent::AppCommand { command, response } => {
312                        if let Some(resp) = response {
313                            let text = match resp {
314                                CommandResponse::Text(msg) => msg.clone(),
315                                CommandResponse::Compact(result) => match result {
316                                    CompactResult::Success(summary) => summary.clone(),
317                                    CompactResult::Cancelled => {
318                                        "Compact command cancelled.".to_string()
319                                    }
320                                    CompactResult::InsufficientMessages => {
321                                        "Not enough messages to compact (minimum 10 required)."
322                                            .to_string()
323                                    }
324                                },
325                            };
326                            format!("/{}\n{}", command.as_command_str(), text)
327                        } else {
328                            format!("/{}", command.as_command_str())
329                        }
330                    }
331                })
332                .collect::<Vec<_>>()
333                .join("\n"),
334            MessageData::Assistant { content } => content
335                .iter()
336                .map(|c| match c {
337                    AssistantContent::Text { text } => text.clone(),
338                    AssistantContent::ToolCall { tool_call } => {
339                        format!("[Tool Call: {}]", tool_call.name)
340                    }
341                    AssistantContent::Thought { thought } => {
342                        format!("[Thought: {}]", thought.display_text())
343                    }
344                })
345                .collect::<Vec<_>>()
346                .join("\n"),
347            MessageData::Tool { result, .. } => {
348                // This is a simplified representation. The TUI will have a more detailed view.
349                let result_type = match result {
350                    ToolResult::Search(_) => "Search Result",
351                    ToolResult::FileList(_) => "File List",
352                    ToolResult::FileContent(_) => "File Content",
353                    ToolResult::Edit(_) => "Edit Result",
354                    ToolResult::Bash(_) => "Bash Result",
355                    ToolResult::Glob(_) => "Glob Result",
356                    ToolResult::TodoRead(_) => "Todo List",
357                    ToolResult::TodoWrite(_) => "Todo Update",
358                    ToolResult::Fetch(_) => "Fetch Result",
359                    ToolResult::Agent(_) => "Agent Result",
360                    ToolResult::External(_) => "External Tool Result",
361                    ToolResult::Error(_) => "Error",
362                };
363                format!("[Tool Result: {result_type}]")
364            }
365        }
366    }
367}
368
/// Prompt sent by `Conversation::compact` as a final user message after a copy of the
/// active thread, asking the model for a structured summary of the conversation so far.
///
/// The model is asked to reason inside `<analysis>` tags and then produce a `<summary>`
/// with seven numbered sections. `compact` stores the text of the model's reply after a
/// `[COMPACTED SUMMARY]` header.
const SUMMARY_PROMPT: &str = r#"Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context.

Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:

1. Chronologically analyze each message and section of the conversation. For each section thoroughly identify:
   - The user's explicit requests and intents
   - Your approach to addressing the user's requests
   - Key decisions, technical concepts and code patterns
   - Specific details like file names, full code snippets, function signatures, file edits, etc
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.

Your summary should include the following sections:

1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
4. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
5. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
6. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
7. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests without confirming with the user first.
                       If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.

Here's an example of how your output should be structured:

<example>
<analysis>
[Your thought process, ensuring all points are covered thoroughly and accurately]
</analysis>

<summary>
1. Primary Request and Intent:
   [Detailed description]

2. Key Technical Concepts:
   - [Concept 1]
   - [Concept 2]
   - [...]

3. Files and Code Sections:
   - [File Name 1]
      - [Summary of why this file is important]
      - [Summary of the changes made to this file, if any]
      - [Important Code Snippet]
   - [File Name 2]
      - [Important Code Snippet]
   - [...]

4. Problem Solving:
   [Description of solved problems and ongoing troubleshooting]

5. Pending Tasks:
   - [Task 1]
   - [Task 2]
   - [...]

6. Current Work:
   [Precise description of current work]

7. Optional Next Step:
   [Optional Next step to take]

</summary>
</example>

Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.

There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. Examples of instructions include:
<example>
## Compact Instructions
When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them.
</example>

<example>
# Summary instructions
When you are using compact - please focus on test output and code changes. Include file reads verbatim.
</example>"#;
446
447/// A conversation history
448#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct Conversation {
450    pub messages: Vec<Message>,
451    pub working_directory: PathBuf,
452    /// The ID of the currently active message (head of the selected branch).
453    /// None means use last message semantics for backward compatibility.
454    pub active_message_id: Option<String>,
455}
456
457impl Default for Conversation {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
463impl Conversation {
464    pub fn new() -> Self {
465        Self {
466            messages: Vec::new(),
467            working_directory: PathBuf::new(),
468            active_message_id: None,
469        }
470    }
471
472    pub fn add_message(&mut self, message: Message) {
473        self.active_message_id = Some(message.id().to_string());
474        self.messages.push(message);
475    }
476
477    pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
478        debug!(target: "conversation::add_message", "Adding message: {:?}", message_data);
479        self.messages.push(Message {
480            data: message_data,
481            id: Message::generate_id("", Message::current_timestamp()),
482            timestamp: Message::current_timestamp(),
483            parent_message_id: self.active_message_id.clone(),
484        });
485        self.active_message_id = Some(self.messages.last().unwrap().id().to_string());
486        self.messages.last().unwrap()
487    }
488
489    pub fn clear(&mut self) {
490        debug!(target:"conversation::clear", "Clearing conversation");
491        self.messages.clear();
492        self.active_message_id = None;
493    }
494
495    /// Find the tool name by its ID by searching through assistant messages with tool calls
496    pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
497        for message in self.messages.iter() {
498            if let MessageData::Assistant { content, .. } = &message.data {
499                for content_block in content {
500                    if let AssistantContent::ToolCall { tool_call } = content_block {
501                        if tool_call.id == tool_id {
502                            return Some(tool_call.name.clone());
503                        }
504                    }
505                }
506            }
507        }
508        None
509    }
510
511    /// Compact the conversation by summarizing older messages in the active thread
512    pub async fn compact(
513        &mut self,
514        api_client: &ApiClient,
515        model: Model,
516        token: CancellationToken,
517    ) -> crate::error::Result<CompactResult> {
518        // Get only the active thread
519        let thread = self.get_active_thread();
520
521        // Skip if we don't have enough messages to compact
522        if thread.len() < 10 {
523            return Ok(CompactResult::InsufficientMessages);
524        }
525
526        // Build prompt from active thread only
527        let mut prompt_messages: Vec<Message> = thread.into_iter().cloned().collect();
528        let last_msg_id = prompt_messages.last().map(|m| m.id().to_string());
529
530        prompt_messages.push(Message {
531            data: MessageData::User {
532                content: vec![UserContent::Text {
533                    text: SUMMARY_PROMPT.to_string(),
534                }],
535            },
536            timestamp: Message::current_timestamp(),
537            id: Message::generate_id("user", Message::current_timestamp()),
538            parent_message_id: last_msg_id.clone(),
539        });
540
541        let summary = tokio::select! {
542            biased;
543            result = api_client.complete(
544                model,
545                prompt_messages,
546                None,
547                None,
548                token.clone(),
549            ) => result.map_err(crate::error::Error::Api)?,
550            _ = token.cancelled() => {
551                return Ok(CompactResult::Cancelled);
552            }
553        };
554
555        let summary_text = summary.extract_text();
556
557        // Create a summary marker message (DO NOT clear messages)
558        let timestamp = Message::current_timestamp();
559        let summary_id = Message::generate_id("user", timestamp);
560
561        // Add the summary as a user message continuing the active thread
562        let summary_message = Message {
563            data: MessageData::User {
564                content: vec![UserContent::Text {
565                    text: format!("[COMPACTED SUMMARY]\n\n{summary_text}"),
566                }],
567            },
568            timestamp,
569            id: summary_id.clone(),
570            parent_message_id: None,
571        };
572
573        self.messages.push(summary_message);
574
575        // Update active_message_id to the summary marker
576        self.active_message_id = Some(summary_id);
577
578        Ok(CompactResult::Success(summary_text))
579    }
580
581    /// Edit a message non-destructively by creating a new branch.
582    /// Returns the ID of the new message if successful.
583    pub fn edit_message(
584        &mut self,
585        message_id: &str,
586        new_content: Vec<UserContent>,
587    ) -> Option<String> {
588        // Find the message to edit
589        let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
590
591        // Only allow editing user messages for now
592        if !matches!(&message_to_edit.data, MessageData::User { .. }) {
593            return None;
594        }
595
596        // Get the parent_message_id from the original message
597        let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
598
599        // Create the new message as a branch from the same parent
600        let new_message_id = Message::generate_id("user", Message::current_timestamp());
601        let edited_message = Message {
602            data: MessageData::User {
603                content: new_content,
604            },
605            timestamp: Message::current_timestamp(),
606            id: new_message_id.clone(),
607            parent_message_id: parent_id,
608        };
609
610        // Add the edited message (original remains in history)
611        self.messages.push(edited_message);
612
613        // Update active_message_id to the new branch head
614        self.active_message_id = Some(new_message_id.clone());
615
616        Some(new_message_id)
617    }
618
619    /// Switch to another branch by setting active_message_id
620    pub fn checkout(&mut self, message_id: &str) -> bool {
621        // Verify the message exists
622        if self.messages.iter().any(|m| m.id() == message_id) {
623            self.active_message_id = Some(message_id.to_string());
624            true
625        } else {
626            false
627        }
628    }
629
630    /// Get messages in the currently active thread
631    pub fn get_active_thread(&self) -> Vec<&Message> {
632        if self.messages.is_empty() {
633            return Vec::new();
634        }
635
636        // Determine the head of the active thread
637        let head_id = if let Some(ref active_id) = self.active_message_id {
638            // Use the explicitly set active message
639            active_id.as_str()
640        } else {
641            // Backward compatibility: use last message
642            self.messages.last().map(|m| m.id()).unwrap_or("")
643        };
644
645        // Find the head message
646        let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
647        if current_msg.is_none() {
648            // If active_message_id is invalid, fall back to last message
649            current_msg = self.messages.last();
650        }
651
652        let mut result = Vec::new();
653        let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
654
655        // Walk backwards from head to root
656        while let Some(msg) = current_msg {
657            result.push(msg);
658
659            // Find parent message using the id_map
660            current_msg = if let Some(parent_id) = msg.parent_message_id() {
661                id_map.get(parent_id).copied()
662            } else {
663                None
664            };
665        }
666
667        result.reverse();
668
669        debug!(
670            "Active thread: [{}]",
671            result
672                .iter()
673                .map(|msg| msg.id())
674                .collect::<Vec<_>>()
675                .join(", ")
676        );
677        result
678    }
679
680    /// Get messages in the current active branch by following parent links from the last message
681    /// This is a thin wrapper around get_active_thread for backward compatibility
682    pub fn get_thread_messages(&self) -> Vec<&Message> {
683        self.get_active_thread()
684    }
685}
686
687#[cfg(test)]
688mod tests {
689    use crate::app::conversation::{
690        AssistantContent, Conversation, Message, MessageData, UserContent,
691    };
692
693    /// Helper function to create a user message for testing
694    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
695        Message {
696            data: MessageData::User {
697                content: vec![UserContent::Text {
698                    text: content.to_string(),
699                }],
700            },
701            timestamp: Message::current_timestamp(),
702            id: id.to_string(),
703            parent_message_id: parent_id.map(String::from),
704        }
705    }
706
707    /// Helper function to create an assistant message for testing
708    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
709        Message {
710            data: MessageData::Assistant {
711                content: vec![AssistantContent::Text {
712                    text: content.to_string(),
713                }],
714            },
715            timestamp: Message::current_timestamp(),
716            id: id.to_string(),
717            parent_message_id: parent_id.map(String::from),
718        }
719    }
720
721    #[test]
722    fn test_editing_message_in_the_middle_of_conversation() {
723        let mut conversation = Conversation::new();
724
725        // 1. Build an initial conversation
726        let msg1 = create_user_message("msg1", None, "What is Rust?");
727        conversation.add_message(msg1.clone());
728
729        let msg2 =
730            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
731        conversation.add_message(msg2.clone());
732
733        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
734        conversation.add_message(msg3.clone());
735
736        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
737        conversation.add_message(msg4.clone());
738
739        // 2. Edit the *first* user message
740        let edited_id = conversation
741            .edit_message(
742                "msg1",
743                vec![UserContent::Text {
744                    text: "What is Golang?".to_string(),
745                }],
746            )
747            .unwrap();
748
749        // 3. Check the state after editing
750        let messages_after_edit = conversation.get_thread_messages();
751        let message_ids_after_edit: Vec<&str> =
752            messages_after_edit.iter().map(|m| m.id()).collect();
753
754        assert_eq!(
755            message_ids_after_edit.len(),
756            1,
757            "Active thread should only show the edited message"
758        );
759        assert_eq!(message_ids_after_edit[0], edited_id.as_str());
760
761        // Verify original branch still exists in messages
762        assert!(conversation.messages.iter().any(|m| m.id() == "msg1"));
763        assert!(conversation.messages.iter().any(|m| m.id() == "msg2"));
764        assert!(conversation.messages.iter().any(|m| m.id() == "msg3"));
765        assert!(conversation.messages.iter().any(|m| m.id() == "msg4"));
766
767        // 4. Add a new message to the new branch of conversation
768        let msg5 = create_assistant_message(
769            "msg5",
770            Some(&edited_id),
771            "A systems programming language from Google.",
772        );
773        conversation.add_message(msg5.clone());
774
775        // 5. Check the final state of the conversation
776        let final_messages = conversation.get_thread_messages();
777        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();
778
779        assert_eq!(
780            final_messages.len(),
781            2,
782            "Should have the edited message and the new response."
783        );
784        assert_eq!(final_message_ids[0], edited_id.as_str());
785        assert_eq!(final_message_ids[1], "msg5");
786    }
787
788    #[test]
789    fn test_get_thread_messages_after_edit() {
790        let mut conversation = Conversation::new();
791
792        // 1. Initial conversation
793        let msg1 = create_user_message("msg1", None, "hello");
794        conversation.add_message(msg1.clone());
795
796        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
797        conversation.add_message(msg2.clone());
798
799        // This is the message that will be "edited out"
800        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
801        conversation.add_message(msg3_original.clone());
802
803        // 2. Edit the last user message ("thanks")
804        let edited_id = conversation
805            .edit_message(
806                "msg3_original",
807                vec![UserContent::Text {
808                    text: "how are you".to_string(),
809                }],
810            )
811            .unwrap();
812
813        // 3. Add a new assistant message to the new branch
814        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
815        conversation.add_message(msg4.clone());
816
817        // 4. Get messages for the current thread
818        let thread_messages = conversation.get_thread_messages();
819
820        // 5. Assertions
821        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();
822
823        // Should contain the root, the first assistant message, the *new* user message, and the final assistant message
824        assert_eq!(
825            thread_message_ids.len(),
826            4,
827            "Should have 4 messages in the current thread"
828        );
829        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
830        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
831        assert!(
832            thread_message_ids.contains(&edited_id.as_str()),
833            "Should contain the edited message"
834        );
835        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");
836
837        // Verify the original message still exists in the full message list
838        assert!(
839            conversation
840                .messages
841                .iter()
842                .any(|m| m.id() == "msg3_original"),
843            "Original message should still exist in conversation history"
844        );
845    }
846
847    #[test]
848    fn test_get_thread_messages_filters_other_branches() {
849        let mut conversation = Conversation::new();
850
851        // 1. Initial conversation: "hi"
852        let msg1 = create_user_message("msg1", None, "hi");
853        conversation.add_message(msg1.clone());
854
855        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
856        conversation.add_message(msg2.clone());
857
858        // 2. User says "thanks" (this will be edited out)
859        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
860        conversation.add_message(msg3_original.clone());
861
862        let msg4_original =
863            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
864        conversation.add_message(msg4_original.clone());
865
866        // 3. Edit the "thanks" message to "how are you"
867        let edited_id = conversation
868            .edit_message(
869                "msg3_original",
870                vec![UserContent::Text {
871                    text: "how are you".to_string(),
872                }],
873            )
874            .unwrap();
875
876        // 4. Add assistant response in the new branch
877        let msg4_new = create_assistant_message(
878            "msg4_new",
879            Some(&edited_id),
880            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
881        );
882        conversation.add_message(msg4_new.clone());
883
884        // 5. User asks "what messages have I sent you?"
885        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
886        conversation.add_message(msg5.clone());
887
888        // 6. Get messages for the current thread - this should NOT include "thanks"
889        let thread_messages = conversation.get_thread_messages();
890
891        // Extract the user messages
892        let user_messages: Vec<String> = thread_messages
893            .iter()
894            .filter(|m| matches!(m.data, MessageData::User { .. }))
895            .map(|m| m.extract_text())
896            .collect();
897
898        println!("User messages seen: {user_messages:?}");
899
900        // Assertions
901        assert_eq!(
902            user_messages.len(),
903            3,
904            "Should have exactly 3 user messages"
905        );
906        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
907        assert_eq!(
908            user_messages[1], "how are you",
909            "Second message should be 'how are you' (edited)"
910        );
911        assert_eq!(
912            user_messages[2], "what messages have I sent you?",
913            "Third message should be the question"
914        );
915
916        // CRITICAL: Should NOT contain "thanks" from the other branch
917        assert!(
918            !user_messages.contains(&"thanks".to_string()),
919            "Should NOT contain 'thanks' from the non-active branch"
920        );
921
922        // But the original message should still exist in the full conversation
923        assert!(
924            conversation
925                .messages
926                .iter()
927                .any(|m| m.id() == "msg3_original"),
928            "Original 'thanks' message should still exist in conversation history"
929        );
930    }
931
932    #[test]
933    fn test_checkout_branch() {
934        let mut conversation = Conversation::new();
935
936        // Create initial conversation
937        let msg1 = create_user_message("msg1", None, "hello");
938        conversation.add_message(msg1.clone());
939
940        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
941        conversation.add_message(msg2.clone());
942
943        // Edit to create a branch
944        let edited_id = conversation
945            .edit_message(
946                "msg1",
947                vec![UserContent::Text {
948                    text: "goodbye".to_string(),
949                }],
950            )
951            .unwrap();
952
953        // Verify we're on the new branch
954        assert_eq!(conversation.active_message_id, Some(edited_id.clone()));
955        let thread = conversation.get_active_thread();
956        assert_eq!(thread.len(), 1);
957        assert_eq!(thread[0].id(), edited_id);
958
959        // Checkout the original branch
960        assert!(conversation.checkout("msg2"));
961        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
962
963        // Verify we're back on the original branch
964        let thread = conversation.get_active_thread();
965        assert_eq!(thread.len(), 2);
966        assert_eq!(thread[0].id(), "msg1");
967        assert_eq!(thread[1].id(), "msg2");
968
969        // Try to checkout non-existent message
970        assert!(!conversation.checkout("non-existent"));
971        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
972    }
973
974    #[test]
975    fn test_active_message_id_tracking() {
976        let mut conversation = Conversation::new();
977
978        // Initially no active message
979        assert_eq!(conversation.active_message_id, None);
980
981        // Add root message - should become active
982        let msg1 = create_user_message("msg1", None, "hello");
983        conversation.add_message(msg1);
984        assert_eq!(conversation.active_message_id, Some("msg1".to_string()));
985
986        // Add response - should update active
987        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
988        conversation.add_message(msg2);
989        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
990
991        // Add another branch from msg1
992        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
993        conversation.add_message(msg3);
994        assert_eq!(conversation.active_message_id, Some("msg3".to_string()));
995
996        // Continue from active
997        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
998        conversation.add_message(msg4);
999        assert_eq!(conversation.active_message_id, Some("msg4".to_string()));
1000    }
1001}