1use crate::api::Client as ApiClient;
2use crate::config::model::ModelId;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt;
6use std::path::PathBuf;
7use std::str::FromStr;
8use std::time::{SystemTime, UNIX_EPOCH};
9use steer_tools::ToolCall;
10pub use steer_tools::result::ToolResult;
11use tracing::debug;
12
13use strum_macros::Display;
14use tokio_util::sync::CancellationToken;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "result_type", rename_all = "snake_case")]
19pub enum CompactResult {
20 Success(String),
22 Cancelled,
24 InsufficientMessages,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(tag = "response_type", rename_all = "snake_case")]
31pub enum CommandResponse {
32 Text(String),
34 Compact(CompactResult),
36}
37
/// A slash command entered by the user: `/model`, `/clear` or `/compact`.
///
/// Serialized internally tagged by `command_type`, with snake_case variant
/// names (`"model"`, `"clear"`, `"compact"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "command_type", rename_all = "snake_case")]
pub enum AppCommandType {
    /// `/model [target]`. When parsed, `target` holds the words after `model`
    /// joined by single spaces, or `None` if there were none.
    Model { target: Option<String> },
    /// `/clear`.
    Clear,
    /// `/compact`.
    Compact,
}
49
50impl AppCommandType {
51 pub fn parse(input: &str) -> Result<Self, SlashCommandError> {
53 let command = input.trim();
55 let command = command.strip_prefix('/').unwrap_or(command);
56
57 let parts: Vec<&str> = command.split_whitespace().collect();
59 if parts.is_empty() {
60 return Err(SlashCommandError::InvalidFormat(
61 "Empty command".to_string(),
62 ));
63 }
64
65 match parts[0] {
66 "model" => {
67 let target = if parts.len() > 1 {
68 Some(parts[1..].join(" "))
69 } else {
70 None
71 };
72 Ok(AppCommandType::Model { target })
73 }
74 "clear" => Ok(AppCommandType::Clear),
75 "compact" => Ok(AppCommandType::Compact),
76 cmd => Err(SlashCommandError::UnknownCommand(cmd.to_string())),
77 }
78 }
79
80 pub fn as_command_str(&self) -> String {
82 match self {
83 AppCommandType::Model { target } => {
84 if let Some(model) = target {
85 format!("model {model}")
86 } else {
87 "model".to_string()
88 }
89 }
90 AppCommandType::Clear => "clear".to_string(),
91 AppCommandType::Compact => "compact".to_string(),
92 }
93 }
94}
95
96impl fmt::Display for AppCommandType {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 write!(f, "/{}", self.as_command_str())
99 }
100}
101
102impl FromStr for AppCommandType {
103 type Err = SlashCommandError;
104
105 fn from_str(s: &str) -> Result<Self, Self::Err> {
106 Self::parse(s)
107 }
108}
109
/// Errors produced by [`AppCommandType::parse`].
#[derive(Debug, thiserror::Error)]
pub enum SlashCommandError {
    /// The first word of the input is not a known command; carries that word.
    #[error("Unknown command: {0}")]
    UnknownCommand(String),
    /// The input is malformed, e.g. it contains no command word at all.
    #[error("Invalid command format: {0}")]
    InvalidFormat(String),
}
118
/// Author role of a [`Message`], derived from its [`MessageData`] variant by
/// [`Message::role`].
///
/// No serde rename is applied, so variants serialize under their Rust names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Copy, Display)]
pub enum Role {
    User,
    Assistant,
    Tool,
}
126
/// One content block of a user message.
///
/// Serialized internally tagged by `type` with snake_case names
/// (`"text"`, `"command_execution"`, `"app_command"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    /// Free-form text typed by the user.
    Text {
        text: String,
    },
    /// A shell command the user ran, with its captured output and exit code.
    CommandExecution {
        command: String,
        stdout: String,
        stderr: String,
        exit_code: i32,
    },
    /// A slash command, plus its response once one is available.
    AppCommand {
        command: AppCommandType,
        response: Option<CommandResponse>,
    },
}
146
impl UserContent {
    /// Renders a shell command execution as an `<executed_command>` XML-like
    /// snippet with `command`, `stdout`, `stderr` and `exit_code` children.
    ///
    /// Values are interpolated verbatim with no XML escaping, so output that
    /// itself contains markup (e.g. `</stdout>`) is passed through unchanged.
    pub fn format_command_execution_as_xml(
        command: &str,
        stdout: &str,
        stderr: &str,
        exit_code: i32,
    ) -> String {
        // The raw string's line breaks and indentation are part of the
        // returned text.
        format!(
            r#"<executed_command>
 <command>{command}</command>
 <stdout>{stdout}</stdout>
 <stderr>{stderr}</stderr>
 <exit_code>{exit_code}</exit_code>
</executed_command>"#
        )
    }
}
164
/// An assistant "thinking" block.
///
/// Serialized internally tagged by `thought_type` with the explicit names
/// `"simple"`, `"signed"` and `"redacted"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "thought_type")]
pub enum ThoughtContent {
    /// Plain thought text.
    #[serde(rename = "simple")]
    Simple { text: String },
    /// Thought text accompanied by a signature string.
    #[serde(rename = "signed")]
    Signed { text: String, signature: String },
    /// Thought whose text is withheld; only opaque `data` is kept.
    #[serde(rename = "redacted")]
    Redacted { data: String },
}
179
180impl ThoughtContent {
181 pub fn display_text(&self) -> String {
183 match self {
184 ThoughtContent::Simple { text } => text.clone(),
185 ThoughtContent::Signed { text, .. } => text.clone(),
186 ThoughtContent::Redacted { .. } => "[Redacted Thinking]".to_string(),
187 }
188 }
189}
190
/// One content block of an assistant message.
///
/// Serialized internally tagged by `type` with snake_case names
/// (`"text"`, `"tool_call"`, `"thought"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AssistantContent {
    /// Visible reply text.
    Text { text: String },
    /// A request to invoke a tool.
    ToolCall { tool_call: ToolCall },
    /// A thinking block (see [`ThoughtContent`]).
    Thought { thought: ThoughtContent },
}
199
/// A single node in the conversation tree.
///
/// Messages form a tree through `parent_message_id`; a conversation's active
/// thread is the path from a message back to its root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Unique identifier, typically produced by [`Message::generate_id`].
    pub id: String,
    /// Id of the message this one follows, or `None` for a root.
    pub parent_message_id: Option<String>,
    /// Role-specific payload.
    pub data: MessageData,
}
207
/// Role-specific payload of a [`Message`].
///
/// Serialized internally tagged by `role` with lowercase names
/// (`"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum MessageData {
    /// Input from the user.
    User {
        content: Vec<UserContent>,
    },
    /// Output from the model.
    Assistant {
        content: Vec<AssistantContent>,
    },
    /// Result of a tool call; `tool_use_id` presumably matches the id of the
    /// originating `ToolCall` (see `Conversation::find_tool_name_by_id`).
    Tool {
        tool_use_id: String,
        result: ToolResult,
    },
}
223
224impl Message {
225 pub fn role(&self) -> Role {
226 match &self.data {
227 MessageData::User { .. } => Role::User,
228 MessageData::Assistant { .. } => Role::Assistant,
229 MessageData::Tool { .. } => Role::Tool,
230 }
231 }
232
233 pub fn id(&self) -> &str {
234 &self.id
235 }
236
237 pub fn timestamp(&self) -> u64 {
238 self.timestamp
239 }
240
241 pub fn parent_message_id(&self) -> Option<&str> {
242 self.parent_message_id.as_deref()
243 }
244
245 pub fn current_timestamp() -> u64 {
247 SystemTime::now()
248 .duration_since(UNIX_EPOCH)
249 .expect("Time went backwards")
250 .as_secs()
251 }
252
253 pub fn generate_id(prefix: &str, _timestamp: u64) -> String {
255 use uuid::Uuid;
256 format!("{}_{}", prefix, Uuid::now_v7())
257 }
258
259 pub fn extract_text(&self) -> String {
261 match &self.data {
262 MessageData::User { content } => content
263 .iter()
264 .filter_map(|c| match c {
265 UserContent::Text { text } => Some(text.clone()),
266 UserContent::CommandExecution { stdout, .. } => Some(stdout.clone()),
267 UserContent::AppCommand { response, .. } => {
268 response.as_ref().map(|r| match r {
269 CommandResponse::Text(t) => t.clone(),
270 CommandResponse::Compact(CompactResult::Success(s)) => s.clone(),
271 _ => String::new(),
272 })
273 }
274 })
275 .collect::<Vec<_>>()
276 .join("\n"),
277 MessageData::Assistant { content } => content
278 .iter()
279 .filter_map(|c| match c {
280 AssistantContent::Text { text } => Some(text.clone()),
281 _ => None,
282 })
283 .collect::<Vec<_>>()
284 .join("\n"),
285 MessageData::Tool { result, .. } => result.llm_format(),
286 }
287 }
288
289 pub fn content_string(&self) -> String {
291 match &self.data {
292 MessageData::User { content } => content
293 .iter()
294 .map(|c| match c {
295 UserContent::Text { text } => text.clone(),
296 UserContent::CommandExecution {
297 command,
298 stdout,
299 stderr,
300 exit_code,
301 } => {
302 let mut output = format!("$ {command}\n{stdout}");
303 if *exit_code != 0 {
304 output.push_str(&format!("\nExit code: {exit_code}"));
305 if !stderr.is_empty() {
306 output.push_str(&format!("\nError: {stderr}"));
307 }
308 }
309 output
310 }
311 UserContent::AppCommand { command, response } => {
312 if let Some(resp) = response {
313 let text = match resp {
314 CommandResponse::Text(msg) => msg.clone(),
315 CommandResponse::Compact(result) => match result {
316 CompactResult::Success(summary) => summary.clone(),
317 CompactResult::Cancelled => {
318 "Compact command cancelled.".to_string()
319 }
320 CompactResult::InsufficientMessages => {
321 "Not enough messages to compact (minimum 10 required)."
322 .to_string()
323 }
324 },
325 };
326 format!("/{}\n{}", command.as_command_str(), text)
327 } else {
328 format!("/{}", command.as_command_str())
329 }
330 }
331 })
332 .collect::<Vec<_>>()
333 .join("\n"),
334 MessageData::Assistant { content } => content
335 .iter()
336 .map(|c| match c {
337 AssistantContent::Text { text } => text.clone(),
338 AssistantContent::ToolCall { tool_call } => {
339 format!("[Tool Call: {}]", tool_call.name)
340 }
341 AssistantContent::Thought { thought } => {
342 format!("[Thought: {}]", thought.display_text())
343 }
344 })
345 .collect::<Vec<_>>()
346 .join("\n"),
347 MessageData::Tool { result, .. } => {
348 let result_type = match result {
350 ToolResult::Search(_) => "Search Result",
351 ToolResult::FileList(_) => "File List",
352 ToolResult::FileContent(_) => "File Content",
353 ToolResult::Edit(_) => "Edit Result",
354 ToolResult::Bash(_) => "Bash Result",
355 ToolResult::Glob(_) => "Glob Result",
356 ToolResult::TodoRead(_) => "Todo List",
357 ToolResult::TodoWrite(_) => "Todo Update",
358 ToolResult::Fetch(_) => "Fetch Result",
359 ToolResult::Agent(_) => "Agent Result",
360 ToolResult::External(_) => "External Tool Result",
361 ToolResult::Error(_) => "Error",
362 };
363 format!("[Tool Result: {result_type}]")
364 }
365 }
366 }
367}
368
/// Instruction sent as the final user message when running `/compact`.
///
/// `Conversation::compact` appends it to a copy of the active thread and asks
/// the model for a structured summary. That summary then becomes the new root
/// of the active thread. The exact text (including whitespace) is sent to the
/// model verbatim.
const SUMMARY_PROMPT: &str = r#"Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing development work without losing context.

Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:

1. Chronologically analyze each message and section of the conversation. For each section thoroughly identify:
 - The user's explicit requests and intents
 - Your approach to addressing the user's requests
 - Key decisions, technical concepts and code patterns
 - Specific details like file names, full code snippets, function signatures, file edits, etc
2. Double-check for technical accuracy and completeness, addressing each required element thoroughly.

Your summary should include the following sections:

1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
2. Key Technical Concepts: List all important technical concepts, technologies, and frameworks discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
4. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
5. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
6. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
7. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests without confirming with the user first.
 If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.

Here's an example of how your output should be structured:

<example>
<analysis>
[Your thought process, ensuring all points are covered thoroughly and accurately]
</analysis>

<summary>
1. Primary Request and Intent:
 [Detailed description]

2. Key Technical Concepts:
 - [Concept 1]
 - [Concept 2]
 - [...]

3. Files and Code Sections:
 - [File Name 1]
 - [Summary of why this file is important]
 - [Summary of the changes made to this file, if any]
 - [Important Code Snippet]
 - [File Name 2]
 - [Important Code Snippet]
 - [...]

4. Problem Solving:
 [Description of solved problems and ongoing troubleshooting]

5. Pending Tasks:
 - [Task 1]
 - [Task 2]
 - [...]

6. Current Work:
 [Precise description of current work]

7. Optional Next Step:
 [Optional Next step to take]

</summary>
</example>

Please provide your summary based on the conversation so far, following this structure and ensuring precision and thoroughness in your response.

There may be additional summarization instructions provided in the included context. If so, remember to follow these instructions when creating the above summary. Examples of instructions include:
<example>
## Compact Instructions
When summarizing the conversation focus on typescript code changes and also remember the mistakes you made and how you fixed them.
</example>

<example>
# Summary instructions
When you are using compact - please focus on test output and code changes. Include file reads verbatim.
</example>"#;
446
/// A branching conversation: every message ever added, plus a pointer to the
/// currently active leaf.
///
/// Edits and compactions add new branches rather than deleting messages; the
/// visible history is the path from `active_message_id` back to its root
/// (see `Conversation::get_active_thread`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
    /// All messages across all branches, in insertion order.
    pub messages: Vec<Message>,
    /// Working directory associated with the conversation. It is empty after
    /// `new()`; presumably set by the caller. TODO confirm its use.
    pub working_directory: PathBuf,
    /// Id of the head of the active thread, or `None` when nothing is active.
    pub active_message_id: Option<String>,
}
456
457impl Default for Conversation {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
463impl Conversation {
464 pub fn new() -> Self {
465 Self {
466 messages: Vec::new(),
467 working_directory: PathBuf::new(),
468 active_message_id: None,
469 }
470 }
471
472 pub fn add_message(&mut self, message: Message) {
473 self.active_message_id = Some(message.id().to_string());
474 self.messages.push(message);
475 }
476
477 pub fn add_message_from_data(&mut self, message_data: MessageData) -> &Message {
478 debug!(target: "conversation::add_message", "Adding message: {:?}", message_data);
479 self.messages.push(Message {
480 data: message_data,
481 id: Message::generate_id("", Message::current_timestamp()),
482 timestamp: Message::current_timestamp(),
483 parent_message_id: self.active_message_id.clone(),
484 });
485 self.active_message_id = Some(self.messages.last().unwrap().id().to_string());
486 self.messages.last().unwrap()
487 }
488
489 pub fn clear(&mut self) {
490 debug!(target:"conversation::clear", "Clearing conversation");
491 self.messages.clear();
492 self.active_message_id = None;
493 }
494
495 pub fn find_tool_name_by_id(&self, tool_id: &str) -> Option<String> {
497 for message in self.messages.iter() {
498 if let MessageData::Assistant { content, .. } = &message.data {
499 for content_block in content {
500 if let AssistantContent::ToolCall { tool_call } = content_block {
501 if tool_call.id == tool_id {
502 return Some(tool_call.name.clone());
503 }
504 }
505 }
506 }
507 }
508 None
509 }
510
511 pub async fn compact(
513 &mut self,
514 api_client: &ApiClient,
515 model: ModelId,
516 token: CancellationToken,
517 ) -> crate::error::Result<CompactResult> {
518 let thread = self.get_active_thread();
520
521 if thread.len() < 10 {
523 return Ok(CompactResult::InsufficientMessages);
524 }
525
526 let mut prompt_messages: Vec<Message> = thread.into_iter().cloned().collect();
528 let last_msg_id = prompt_messages.last().map(|m| m.id().to_string());
529
530 prompt_messages.push(Message {
531 data: MessageData::User {
532 content: vec![UserContent::Text {
533 text: SUMMARY_PROMPT.to_string(),
534 }],
535 },
536 timestamp: Message::current_timestamp(),
537 id: Message::generate_id("user", Message::current_timestamp()),
538 parent_message_id: last_msg_id.clone(),
539 });
540
541 let summary = tokio::select! {
542 biased;
543 result = api_client.complete(
544 &model,
545 prompt_messages,
546 None,
547 None,
548 None,
549 token.clone(),
550 ) => result.map_err(crate::error::Error::Api)?,
551 _ = token.cancelled() => {
552 return Ok(CompactResult::Cancelled);
553 }
554 };
555
556 let summary_text = summary.extract_text();
557
558 let timestamp = Message::current_timestamp();
560 let summary_id = Message::generate_id("user", timestamp);
561
562 let summary_message = Message {
564 data: MessageData::User {
565 content: vec![UserContent::Text {
566 text: format!("[COMPACTED SUMMARY]\n\n{summary_text}"),
567 }],
568 },
569 timestamp,
570 id: summary_id.clone(),
571 parent_message_id: None,
572 };
573
574 self.messages.push(summary_message);
575
576 self.active_message_id = Some(summary_id);
578
579 Ok(CompactResult::Success(summary_text))
580 }
581
582 pub fn edit_message(
585 &mut self,
586 message_id: &str,
587 new_content: Vec<UserContent>,
588 ) -> Option<String> {
589 let message_to_edit = self.messages.iter().find(|m| m.id() == message_id)?;
591
592 if !matches!(&message_to_edit.data, MessageData::User { .. }) {
594 return None;
595 }
596
597 let parent_id = message_to_edit.parent_message_id().map(|s| s.to_string());
599
600 let new_message_id = Message::generate_id("user", Message::current_timestamp());
602 let edited_message = Message {
603 data: MessageData::User {
604 content: new_content,
605 },
606 timestamp: Message::current_timestamp(),
607 id: new_message_id.clone(),
608 parent_message_id: parent_id,
609 };
610
611 self.messages.push(edited_message);
613
614 self.active_message_id = Some(new_message_id.clone());
616
617 Some(new_message_id)
618 }
619
620 pub fn checkout(&mut self, message_id: &str) -> bool {
622 if self.messages.iter().any(|m| m.id() == message_id) {
624 self.active_message_id = Some(message_id.to_string());
625 true
626 } else {
627 false
628 }
629 }
630
631 pub fn get_active_thread(&self) -> Vec<&Message> {
633 if self.messages.is_empty() {
634 return Vec::new();
635 }
636
637 let head_id = if let Some(ref active_id) = self.active_message_id {
639 active_id.as_str()
641 } else {
642 self.messages.last().map(|m| m.id()).unwrap_or("")
644 };
645
646 let mut current_msg = self.messages.iter().find(|m| m.id() == head_id);
648 if current_msg.is_none() {
649 current_msg = self.messages.last();
651 }
652
653 let mut result = Vec::new();
654 let id_map: HashMap<&str, &Message> = self.messages.iter().map(|m| (m.id(), m)).collect();
655
656 while let Some(msg) = current_msg {
658 result.push(msg);
659
660 current_msg = if let Some(parent_id) = msg.parent_message_id() {
662 id_map.get(parent_id).copied()
663 } else {
664 None
665 };
666 }
667
668 result.reverse();
669
670 debug!(
671 "Active thread: [{}]",
672 result
673 .iter()
674 .map(|msg| msg.id())
675 .collect::<Vec<_>>()
676 .join(", ")
677 );
678 result
679 }
680
681 pub fn get_thread_messages(&self) -> Vec<&Message> {
684 self.get_active_thread()
685 }
686}
687
/// Unit tests for the branching behavior of `Conversation`: editing,
/// checkout and active-thread resolution.
#[cfg(test)]
mod tests {
    use crate::app::conversation::{
        AssistantContent, Conversation, Message, MessageData, UserContent,
    };

    /// Builds a user text message with an explicit id and optional parent.
    fn create_user_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    /// Builds an assistant text message with an explicit id and optional parent.
    fn create_assistant_message(id: &str, parent_id: Option<&str>, content: &str) -> Message {
        Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: content.to_string(),
                }],
            },
            timestamp: Message::current_timestamp(),
            id: id.to_string(),
            parent_message_id: parent_id.map(String::from),
        }
    }

    // Editing the root message starts a new branch containing only the edit,
    // while the old branch stays in history.
    #[test]
    fn test_editing_message_in_the_middle_of_conversation() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "What is Rust?");
        conversation.add_message(msg1.clone());

        let msg2 =
            create_assistant_message("msg2", Some("msg1"), "A systems programming language.");
        conversation.add_message(msg2.clone());

        let msg3 = create_user_message("msg3", Some("msg2"), "Is it fast?");
        conversation.add_message(msg3.clone());

        let msg4 = create_assistant_message("msg4", Some("msg3"), "Yes, it is very fast.");
        conversation.add_message(msg4.clone());

        let edited_id = conversation
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "What is Golang?".to_string(),
                }],
            )
            .unwrap();

        let messages_after_edit = conversation.get_thread_messages();
        let message_ids_after_edit: Vec<&str> =
            messages_after_edit.iter().map(|m| m.id()).collect();

        assert_eq!(
            message_ids_after_edit.len(),
            1,
            "Active thread should only show the edited message"
        );
        assert_eq!(message_ids_after_edit[0], edited_id.as_str());

        // The superseded branch is still stored.
        assert!(conversation.messages.iter().any(|m| m.id() == "msg1"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg2"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg3"));
        assert!(conversation.messages.iter().any(|m| m.id() == "msg4"));

        let msg5 = create_assistant_message(
            "msg5",
            Some(&edited_id),
            "A systems programming language from Google.",
        );
        conversation.add_message(msg5.clone());

        let final_messages = conversation.get_thread_messages();
        let final_message_ids: Vec<&str> = final_messages.iter().map(|m| m.id()).collect();

        assert_eq!(
            final_messages.len(),
            2,
            "Should have the edited message and the new response."
        );
        assert_eq!(final_message_ids[0], edited_id.as_str());
        assert_eq!(final_message_ids[1], "msg5");
    }

    // Editing a later message keeps its ancestors in the active thread.
    #[test]
    fn test_get_thread_messages_after_edit() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "world");
        conversation.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        conversation.add_message(msg3_original.clone());

        let edited_id = conversation
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4 = create_assistant_message("msg4", Some(&edited_id), "I am fine");
        conversation.add_message(msg4.clone());

        let thread_messages = conversation.get_thread_messages();

        let thread_message_ids: Vec<&str> = thread_messages.iter().map(|m| m.id()).collect();

        assert_eq!(
            thread_message_ids.len(),
            4,
            "Should have 4 messages in the current thread"
        );
        assert!(thread_message_ids.contains(&"msg1"), "Should contain msg1");
        assert!(thread_message_ids.contains(&"msg2"), "Should contain msg2");
        assert!(
            thread_message_ids.contains(&edited_id.as_str()),
            "Should contain the edited message"
        );
        assert!(thread_message_ids.contains(&"msg4"), "Should contain msg4");

        assert!(
            conversation
                .messages
                .iter()
                .any(|m| m.id() == "msg3_original"),
            "Original message should still exist in conversation history"
        );
    }

    // Messages on the non-active branch must not leak into the active thread.
    #[test]
    fn test_get_thread_messages_filters_other_branches() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hi");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "Hello! How can I help?");
        conversation.add_message(msg2.clone());

        let msg3_original = create_user_message("msg3_original", Some("msg2"), "thanks");
        conversation.add_message(msg3_original.clone());

        let msg4_original =
            create_assistant_message("msg4_original", Some("msg3_original"), "You're welcome!");
        conversation.add_message(msg4_original.clone());

        let edited_id = conversation
            .edit_message(
                "msg3_original",
                vec![UserContent::Text {
                    text: "how are you".to_string(),
                }],
            )
            .unwrap();

        let msg4_new = create_assistant_message(
            "msg4_new",
            Some(&edited_id),
            "I'm doing well, thanks for asking! Ready to help with any software engineering tasks you have.",
        );
        conversation.add_message(msg4_new.clone());

        let msg5 = create_user_message("msg5", Some("msg4_new"), "what messages have I sent you?");
        conversation.add_message(msg5.clone());

        let thread_messages = conversation.get_thread_messages();

        let user_messages: Vec<String> = thread_messages
            .iter()
            .filter(|m| matches!(m.data, MessageData::User { .. }))
            .map(|m| m.extract_text())
            .collect();

        println!("User messages seen: {user_messages:?}");

        assert_eq!(
            user_messages.len(),
            3,
            "Should have exactly 3 user messages"
        );
        assert_eq!(user_messages[0], "hi", "First message should be 'hi'");
        assert_eq!(
            user_messages[1], "how are you",
            "Second message should be 'how are you' (edited)"
        );
        assert_eq!(
            user_messages[2], "what messages have I sent you?",
            "Third message should be the question"
        );

        assert!(
            !user_messages.contains(&"thanks".to_string()),
            "Should NOT contain 'thanks' from the non-active branch"
        );

        assert!(
            conversation
                .messages
                .iter()
                .any(|m| m.id() == "msg3_original"),
            "Original 'thanks' message should still exist in conversation history"
        );
    }

    // `checkout` switches branches and rejects unknown ids without side effects.
    #[test]
    fn test_checkout_branch() {
        let mut conversation = Conversation::new();

        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1.clone());

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi there");
        conversation.add_message(msg2.clone());

        let edited_id = conversation
            .edit_message(
                "msg1",
                vec![UserContent::Text {
                    text: "goodbye".to_string(),
                }],
            )
            .unwrap();

        assert_eq!(conversation.active_message_id, Some(edited_id.clone()));
        let thread = conversation.get_active_thread();
        assert_eq!(thread.len(), 1);
        assert_eq!(thread[0].id(), edited_id);

        assert!(conversation.checkout("msg2"));
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));

        let thread = conversation.get_active_thread();
        assert_eq!(thread.len(), 2);
        assert_eq!(thread[0].id(), "msg1");
        assert_eq!(thread[1].id(), "msg2");

        assert!(!conversation.checkout("non-existent"));
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));
    }

    // `add_message` always moves the active pointer to the newly added message.
    #[test]
    fn test_active_message_id_tracking() {
        let mut conversation = Conversation::new();

        assert_eq!(conversation.active_message_id, None);

        let msg1 = create_user_message("msg1", None, "hello");
        conversation.add_message(msg1);
        assert_eq!(conversation.active_message_id, Some("msg1".to_string()));

        let msg2 = create_assistant_message("msg2", Some("msg1"), "hi");
        conversation.add_message(msg2);
        assert_eq!(conversation.active_message_id, Some("msg2".to_string()));

        let msg3 = create_user_message("msg3", Some("msg1"), "different question");
        conversation.add_message(msg3);
        assert_eq!(conversation.active_message_id, Some("msg3".to_string()));

        let msg4 = create_user_message("msg4", Some("msg3"), "follow up");
        conversation.add_message(msg4);
        assert_eq!(conversation.active_message_id, Some("msg4".to_string()));
    }
}