1use crate::api::Client as ApiClient;
2use crate::api::Model;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::path::PathBuf;
7use std::str::FromStr;
8use std::time::{SystemTime, UNIX_EPOCH};
9use steer_tools::ToolCall;
10pub use steer_tools::result::ToolResult;
11use tracing::debug;
12
13use strum_macros::Display;
14use tokio_util::sync::CancellationToken;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "result_type", rename_all = "snake_case")]
19pub enum CompactResult {
20 Success(String),
22 Cancelled,
24 InsufficientMessages,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(tag = "response_type", rename_all = "snake_case")]
31pub enum CommandResponse {
32 Text(String),
34 Compact(CompactResult),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40#[serde(tag = "command_type", rename_all = "snake_case")]
41pub enum AppCommandType {
42 Model { target: Option<String> },
44 Clear,
46 Compact,
48}
49
50impl AppCommandType {
51 pub fn parse(input: &str) -> Result<Self, SlashCommandError> {
53 let command = input.trim();
55 let command = command.strip_prefix('/').unwrap_or(command);
56
57 let parts: Vec<&str> = command.split_whitespace().collect();
59 if parts.is_empty() {
60 return Err(SlashCommandError::InvalidFormat(
61 "Empty command".to_string(),
62 ));
63 }
64
65 match parts[0] {
66 "model" => {
67 let target = if parts.len() > 1 {
68 Some(parts[1..].join(" "))
69 } else {
70 None
71 };
72 Ok(AppCommandType::Model { target })
73 }
74 "clear" => Ok(AppCommandType::Clear),
75 "compact" => Ok(AppCommandType::Compact),
76 cmd => Err(SlashCommandError::UnknownCommand(cmd.to_string())),
77 }
78 }
79
80 pub fn as_command_str(&self) -> String {
82 match self {
83 AppCommandType::Model { target } => {
84 if let Some(model) = target {
85 format!("model {model}")
86 } else {
87 "model".to_string()
88 }
89 }
90 AppCommandType::Clear => "clear".to_string(),
91 AppCommandType::Compact => "compact".to_string(),
92 }
93 }
94}
95
96impl fmt::Display for AppCommandType {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 write!(f, "/{}", self.as_command_str())
99 }
100}
101
102impl FromStr for AppCommandType {
103 type Err = SlashCommandError;
104
105 fn from_str(s: &str) -> Result<Self, Self::Err> {
106 Self::parse(s)
107 }
108}
109
110#[derive(Debug, thiserror::Error)]
112pub enum SlashCommandError {
113 #[error("Unknown command: {0}")]
114 UnknownCommand(String),
115 #[error("Invalid command format: {0}")]
116 InvalidFormat(String),
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Copy, Display)]
121pub enum Role {
122 User,
123 Assistant,
124 Tool,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
129#[serde(tag = "type", rename_all = "snake_case")]
130pub enum UserContent {
131 Text {
132 text: String,
133 },
134 CommandExecution {
135 command: String,
136 stdout: String,
137 stderr: String,
138 exit_code: i32,
139 },
140 AppCommand {
141 command: AppCommandType,
142 response: Option<CommandResponse>,
143 },
144 }
146
impl UserContent {
    /// Formats a command execution as an `<executed_command>` block with
    /// `<command>`, `<stdout>`, `<stderr>` and `<exit_code>` children.
    ///
    /// Values are interpolated verbatim; no XML escaping is applied, so output
    /// containing `<`, `>` or `&` is embedded as-is.
    pub fn format_command_execution_as_xml(
        command: &str,
        stdout: &str,
        stderr: &str,
        exit_code: i32,
    ) -> String {
        format!(
            r#"<executed_command>
 <command>{command}</command>
 <stdout>{stdout}</stdout>
 <stderr>{stderr}</stderr>
 <exit_code>{exit_code}</exit_code>
</executed_command>"#
        )
    }
}
164
165#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
167#[serde(tag = "thought_type")]
168pub enum ThoughtContent {
169 #[serde(rename = "simple")]
171 Simple { text: String },
172 #[serde(rename = "signed")]
174 Signed { text: String, signature: String },
175 #[serde(rename = "redacted")]
177 Redacted { data: String },
178}
179
180impl ThoughtContent {
181 pub fn display_text(&self) -> String {
183 match self {
184 ThoughtContent::Simple { text } => text.clone(),
185 ThoughtContent::Signed { text, .. } => text.clone(),
186 ThoughtContent::Redacted { .. } => "[Redacted Thinking]".to_string(),
187 }
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193#[serde(tag = "type", rename_all = "snake_case")]
194pub enum AssistantContent {
195 Text { text: String },
196 ToolCall { tool_call: ToolCall },
197 Thought { thought: ThoughtContent },
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct Message {
202 pub timestamp: u64,
203 pub id: String,
204 pub parent_message_id: Option<String>,
205 pub data: MessageData,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
210#[serde(tag = "role", rename_all = "lowercase")]
211pub enum MessageData {
212 User {
213 content: Vec<UserContent>,
214 },
215 Assistant {
216 content: Vec<AssistantContent>,
217 },
218 Tool {
219 tool_use_id: String,
220 result: ToolResult,
221 },
222}
223
224impl Message {
225 pub fn role(&self) -> Role {
226 match &self.data {
227 MessageData::User { .. } => Role::User,
228 MessageData::Assistant { .. } => Role::Assistant,
229 MessageData::Tool { .. } => Role::Tool,
230 }
231 }
232
233 pub fn id(&self) -> &str {
234 &self.id
235 }
236
237 pub fn timestamp(&self) -> u64 {
238 self.timestamp
239 }
240
241 pub fn parent_message_id(&self) -> Option<&str> {
242 self.parent_message_id.as_deref()
243 }
244
245 pub fn current_timestamp() -> u64 {
247 SystemTime::now()
248 .duration_since(UNIX_EPOCH)
249 .expect("Time went backwards")
250 .as_secs()
251 }
252
253 pub fn generate_id(prefix: &str, _timestamp: u64) -> String {
255 use uuid::Uuid;
256 format!("{}_{}", prefix, Uuid::now_v7())
257 }
258
259 pub fn extract_text(&self) -> String {
261 match &self.data {
262 MessageData::User { content } => content
263 .iter()
264 .filter_map(|c| match c {
265 UserContent::Text { text } => Some(text.clone()),
266 UserContent::CommandExecution { stdout, .. } => Some(stdout.clone()),
267 UserContent::AppCommand { response, .. } => {
268 response.as_ref().map(|r| match r {
269 CommandResponse::Text(t) => t.clone(),
270 CommandResponse::Compact(CompactResult::Success(s)) => s.clone(),
271 _ => String::new(),
272 })
273 }
274 })
275 .collect::<Vec<_>>()
276 .join("\n"),
277 MessageData::Assistant { content } => content
278 .iter()
279 .filter_map(|c| match c {
280 AssistantContent::Text { text } => Some(text.clone()),
281 _ => None,
282 })
283 .collect::<Vec<_>>()
284 .join("\n"),
285 MessageData::Tool { result, .. } => result.llm_format(),
286 }
287 }
288
289 pub fn content_string(&self) -> String {
291 match &self.data {
292 MessageData::User { content } => content
293 .iter()
294 .map(|c| match c {
295 UserContent::Text { text } => text.clone(),
296 UserContent::CommandExecution {
297 command,
298 stdout,
299 stderr,
300 exit_code,
301 } => {
302 let mut output = format!("$ {command}\n{stdout}");
303 if *exit_code != 0 {
304 output.push_str(&format!("\nExit code: {exit_code}"));
305 if !stderr.is_empty() {
306 output.push_str(&format!("\nError: {stderr}"));
307 }
308 }
309 output
310 }
311 UserContent::AppCommand { command, response } => {
312 if let Some(resp) = response {
313 let text = match resp {
314 CommandResponse::Text(msg) => msg.clone(),
315 CommandResponse::Compact(result) => match result {
316 CompactResult::Success(summary) => summary.clone(),
317 CompactResult::Cancelled => {
318 "Compact command cancelled.".to_string()
319 }
320 CompactResult::InsufficientMessages => {
321 "Not enough messages to compact (minimum 10 required)."
322 .to_string()
323 }
324 },
325 };
326 format!("/{}\n{}", command.as_command_str(), text)
327 } else {
328 format!("/{}", command.as_command_str())
329 }
330 }
331 })
332 .collect::<Vec<_>>()
333 .join("\n"),
334 MessageData::Assistant { content } => content
335 .iter()
336 .map(|c| match c {
337 AssistantContent::Text { text } => text.clone(),
338 AssistantContent::ToolCall { tool_call } => {
339 format!("[Tool Call: {}]", tool_call.name)
340 }
341 AssistantContent::Thought { thought } => {
342 format!("[Thought: {}]", thought.display_text())
343 }
344 })
345 .collect::<Vec<_>>()
346 .join("\n"),
347 MessageData::Tool { result, .. } => {
348 let result_type = match result {
350 ToolResult::Search(_) => "Search Result",
351 ToolResult::FileList(_) => "File List",
352 ToolResult::FileContent(_) => "File Content",
353 ToolResult::Edit(_) => "Edit Result",
354 ToolResult::Bash(_) => "Bash Result",
355 ToolResult::Glob(_) => "Glob Result",
356 ToolResult::TodoRead(_) => "Todo List",
357 ToolResult::TodoWrite(_) => "Todo Update",
358 ToolResult::Fetch(_) => "Fetch Result",
359 ToolResult::Agent(_) => "Agent Result",
360 ToolResult::External(_) => "External Tool Result",
361 ToolResult::Error(_) => "Error",
362 };
363 format!("[Tool Result: {result_type}]")
364 }
365 }
366 }
367}
368
/// Instruction prompt that [`Conversation::compact`] appends as a final user
/// message, asking the model for a structured summary of the active thread.
// NOTE(review): the text contains "the users request" (missing apostrophe).
// It is left as-is because editing it changes the prompt sent to the model.
const SUMMARY_PROMPT: &str = r#"Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context.

Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:

1. Chronologically analyze each message and section of the conversation. For each section thoroughly identify:
 - The user's explicit requests and intents
 - Your approach to addressing the user's requests
 - Key decisions, technical concepts and code patterns
 - Specific details like file names, full code snippets, function signatures, file edits, etc
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.

Your summary should include the following sections:

1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
4. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
5. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
6. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
7. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests without confirming with the user first.
 If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.

Here's an example of how your output should be structured:

<example>
<analysis>
[Your thought process, ensuring all points are covered thoroughly and accurately]
</analysis>

<summary>
1. Primary Request and Intent:
 [Detailed description]

2. Key Technical Concepts:
 - [Concept 1]
 - [Concept 2]
 - [...]

3. Files and Code Sections:
 - [File Name 1]
 - [Summary of why this file is important]
 - [Summary of the changes made to this file, if any]
 - [Important Code Snippet]
 - [File Name 2]
 - [Important Code Snippet]
 - [...]

4. Problem Solving:
 [Description of solved problems and ongoing troubleshooting]

5. Pending Tasks:
 - [Task 1]
 - [Task 2]
 - [...]

6. Current Work:
 [Precise description of current work]

7. Optional Next Step:
 [Optional Next step to take]

</summary>
</example>

Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.

There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. Examples of instructions include:
<example>
## Compact Instructions
When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them.
</example>

<example>
# Summary instructions
When you are using compact - please focus on test output and code changes. Include file reads verbatim.
</example>"#;
446
447#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct Conversation {
450 pub messages: Vec<Message>,
451 pub working_directory: PathBuf,
452 pub active_message_id: Option<String>,
455}
456
457impl Default for Conversation {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
463impl Conversation {
464 pub fn new() -> Self {
465 Self {
466 messages: Vec::new(),
467 working_directory: PathBuf::new(),
468 active_message_id: None,
469 }
470 }
471
472 pub fn add_message(&mut self, message: Message) {
473 self.active_message_id = Some(message.id().to_string());
474 self.messages.push(message);
475 }
476
477 pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
478 debug!(target: "conversation::add_message", "Adding message: {:?}", message_data);
479 self.messages.push(Message {
480 data: message_data,
481 id: Message::generate_id("", Message::current_timestamp()),
482 timestamp: Message::current_timestamp(),
483 parent_message_id: self.active_message_id.clone(),
484 });
485 self.active_message_id = Some(self.messages.last().unwrap().id().to_string());
486 self.messages.last().unwrap()
487 }
488
489 pub fn clear(&mut self) {
490 debug!(target:"conversation::clear", "Clearing conversation");
491 self.messages.clear();
492 self.active_message_id = None;
493 }
494
495 pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
497 for message in self.messages.iter() {
498 if let MessageData::Assistant { content, .. } = &message.data {
499 for content_block in content {
500 if let AssistantContent::ToolCall { tool_call } = content_block {
501 if tool_call.id == tool_id {
502 return Some(tool_call.name.clone());
503 }
504 }
505 }
506 }
507 }
508 None
509 }
510
511 pub async fn compact(
513 &mut self,
514 api_client: &ApiClient,
515 model: Model,
516 token: CancellationToken,
517 ) -> crate::error::Result<CompactResult> {
518 let thread = self.get_active_thread();
520
521 if thread.len() < 10 {
523 return Ok(CompactResult::InsufficientMessages);
524 }
525
526 let mut prompt_messages: Vec<Message> = thread.into_iter().cloned().collect();
528 let last_msg_id = prompt_messages.last().map(|m| m.id().to_string());
529
530 prompt_messages.push(Message {
531 data: MessageData::User {
532 content: vec![UserContent::Text {
533 text: SUMMARY_PROMPT.to_string(),
534 }],
535 },
536 timestamp: Message::current_timestamp(),
537 id: Message::generate_id("user", Message::current_timestamp()),
538 parent_message_id: last_msg_id.clone(),
539 });
540
541 let summary = tokio::select! {
542 biased;
543 result = api_client.complete(
544 model,
545 prompt_messages,
546 None,
547 None,
548 token.clone(),
549 ) => result.map_err(crate::error::Error::Api)?,
550 _ = token.cancelled() => {
551 return Ok(CompactResult::Cancelled);
552 }
553 };
554
555 let summary_text = summary.extract_text();
556
557 let timestamp = Message::current_timestamp();
559 let summary_id = Message::generate_id("user", timestamp);
560
561 let summary_message = Message {
563 data: MessageData::User {
564 content: vec![UserContent::Text {
565 text: format!("[COMPACTED SUMMARY]\n\n{summary_text}"),
566 }],
567 },
568 timestamp,
569 id: summary_id.clone(),
570 parent_message_id: last_msg_id, };
572
573 self.messages.push(summary_message);
574
575 self.active_message_id = Some(summary_id);
577
578 Ok(CompactResult::Success(summary_text))
579 }
580
581 pub fn edit_message(
584 &mut self,
585 message_id: &str,
586 new_content: Vec<UserContent>,
587 ) -> Option<String> {
588 let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
590
591 if !matches!(&message_to_edit.data, MessageData::User { .. }) {
593 return None;
594 }
595
596 let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
598
599 let new_message_id = Message::generate_id("user", Message::current_timestamp());
601 let edited_message = Message {
602 data: MessageData::User {
603 content: new_content,
604 },
605 timestamp: Message::current_timestamp(),
606 id: new_message_id.clone(),
607 parent_message_id: parent_id,
608 };
609
610 self.messages.push(edited_message);
612
613 self.active_message_id = Some(new_message_id.clone());
615
616 Some(new_message_id)
617 }
618
619 pub fn checkout(&mut self, message_id: &str) -> bool {
621 if self.messages.iter().any(|m| m.id() == message_id) {
623 self.active_message_id = Some(message_id.to_string());
624 true
625 } else {
626 false
627 }
628 }
629
630 pub fn get_active_thread(&self) -> Vec<&Message> {
632 if self.messages.is_empty() {
633 return Vec::new();
634 }
635
636 let head_id = if let Some(ref active_id) = self.active_message_id {
638 active_id.as_str()
640 } else {
641 self.messages.last().map(|m| m.id()).unwrap_or("")
643 };
644
645 let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
647 if current_msg.is_none() {
648 current_msg = self.messages.last();
650 }
651
652 let mut result = Vec::new();
653 let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
654
655 while let Some(msg) = current_msg {
657 result.push(msg);
658
659 current_msg = if let Some(parent_id) = msg.parent_message_id() {
661 id_map.get(parent_id).copied()
662 } else {
663 None
664 };
665 }
666
667 result.reverse();
668
669 debug!(
670 "Active thread: [{}]",
671 result
672 .iter()
673 .map(|msg| msg.id())
674 .collect::<Vec<_>>()
675 .join(", ")
676 );
677 result
678 }
679
680 pub fn get_thread_messages(&self) -> Vec<&Message> {
683 self.get_active_thread()
684 }
685}
686
// Tests for the branching conversation model. Editing a message creates a
// sibling branch, `get_active_thread`/`get_thread_messages` follow parent
// links back from the active head, and `checkout` moves that head.
#[cfg(test)]
mod tests {
    use crate::app::conversation::{
        AssistantContent, Conversation, Message, MessageData, UserContent,
    };

    // Builds a single-text user message with an explicit id and parent link.
    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    // Builds a single-text assistant message with an explicit id and parent link.
    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    #[test]
    fn test_editing_message_in_the_middle_of_conversation() {
        let mut conversation = Conversation::new();

        // Linear thread: msg1 -> msg2 -> msg3 -> msg4.
        let msg1 = create_user_message("msg1", None, "What is Rust?");
        conversation.add_message(msg1.clone());

        let msg2 =
            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
        conversation.add_message(msg2.clone());

        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
        conversation.add_message(msg3.clone());

        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
        conversation.add_message(msg4.clone());

        // Editing the root creates a new root (no parent) and makes it active.
        let edited_id = conversation
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "What is Golang?".to_string(),
                }],
            )
            .unwrap();

        let messages_after_edit = conversation.get_thread_messages();
        let message_ids_after_edit: Vec<&str> =
            messages_after_edit.iter().map(|m| m.id()).collect();

        assert_eq!(
            message_ids_after_edit.len(),
            1,
            "Active thread should only show the edited message"
        );
        assert_eq!(message_ids_after_edit[0], edited_id.as_str());

        // The original branch is still stored, just not active.
        assert!(conversation.messages.iter().any(|m| m.id() == "msg1"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg2"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg3"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg4"));

        // Replying to the edited message extends the new branch.
        let msg5 = create_assistant_message(
            "msg5",
            Some(&edited_id),
            "A systems programming language from Google.",
        );
        conversation.add_message(msg5.clone());

        let final_messages = conversation.get_thread_messages();
        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();

        assert_eq!(
            final_messages.len(),
            2,
            "Should have the edited message and the new response."
        );
        assert_eq!(final_message_ids[0], edited_id.as_str());
        assert_eq!(final_message_ids[1], "msg5");
    }

    #[test]
    fn test_get_thread_messages_after_edit() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
        conversation.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        conversation.add_message(msg3_original.clone());

        // Editing the last user message forks a sibling under msg2.
        let edited_id = conversation
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
        conversation.add_message(msg4.clone());

        let thread_messages = conversation.get_thread_messages();

        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();

        // Expected active thread: msg1 -> msg2 -> edited -> msg4.
        assert_eq!(
            thread_message_ids.len(),
            4,
            "Should have 4 messages in the current thread"
        );
        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
        assert!(
            thread_message_ids.contains(&edited_id.as_str()),
            "Should contain the edited message"
        );
        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");

        assert!(
            conversation
                .messages
                .iter()
                .any(|m| m.id() == "msg3_original"),
            "Original message should still exist in conversation history"
        );
    }

    #[test]
    fn test_get_thread_messages_filters_other_branches() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hi");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
        conversation.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        conversation.add_message(msg3_original.clone());

        let msg4_original =
            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
        conversation.add_message(msg4_original.clone());

        // Fork at msg3_original; the old msg3/msg4 branch must disappear from the thread.
        let edited_id = conversation
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4_new = create_assistant_message(
            "msg4_new",
            Some(&edited_id),
            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
        );
        conversation.add_message(msg4_new.clone());

        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
        conversation.add_message(msg5.clone());

        let thread_messages = conversation.get_thread_messages();

        // Only user-authored text on the active branch, in root-to-head order.
        let user_messages: Vec<String> = thread_messages
            .iter()
            .filter(|m| matches!(m.data, MessageData::User { .. }))
            .map(|m| m.extract_text())
            .collect();

        println!("User messages seen: {user_messages:?}");

        assert_eq!(
            user_messages.len(),
            3,
            "Should have exactly 3 user messages"
        );
        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
        assert_eq!(
            user_messages[1], "how are you",
            "Second message should be 'how are you' (edited)"
        );
        assert_eq!(
            user_messages[2], "what messages have I sent you?",
            "Third message should be the question"
        );

        assert!(
            !user_messages.contains(&"thanks".to_string()),
            "Should NOT contain 'thanks' from the non-active branch"
        );

        assert!(
            conversation
                .messages
                .iter()
                .any(|m| m.id() == "msg3_original"),
            "Original 'thanks' message should still exist in conversation history"
        );
    }

    #[test]
    fn test_checkout_branch() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
        conversation.add_message(msg2.clone());

        let edited_id = conversation
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "goodbye".to_string(),
                }],
            )
            .unwrap();

        // After the edit, the new root is the active head.
        assert_eq!(conversation.active_message_id, Some(edited_id.clone()));
        let thread = conversation.get_active_thread();
        assert_eq!(thread.len(), 1);
        assert_eq!(thread[0].id(), edited_id);

        // Checking out the old branch's tip restores that thread.
        assert!(conversation.checkout("msg2"));
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));

        let thread = conversation.get_active_thread();
        assert_eq!(thread.len(), 2);
        assert_eq!(thread[0].id(), "msg1");
        assert_eq!(thread[1].id(), "msg2");

        // An unknown id is rejected and leaves the head unchanged.
        assert!(!conversation.checkout("non-existent"));
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
    }

    #[test]
    fn test_active_message_id_tracking() {
        let mut conversation = Conversation::new();

        assert_eq!(conversation.active_message_id, None);

        // `add_message` always moves the head to the message just added,
        // regardless of which parent that message declares.
        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1);
        assert_eq!(conversation.active_message_id, Some("msg1".to_string()));

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
        conversation.add_message(msg2);
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));

        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
        conversation.add_message(msg3);
        assert_eq!(conversation.active_message_id, Some("msg3".to_string()));

        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
        conversation.add_message(msg4);
        assert_eq!(conversation.active_message_id, Some("msg4".to_string()));
    }
}
1001}