use rig::completion::Message;
use serde::{Deserialize, Serialize};

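/// Fraction of the context window at which history compaction kicks in.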
pub const DEFAULT_COMPRESSION_THRESHOLD: f32 = 0.85;

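/// Fraction of the most recent turns kept verbatim when compacting.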
pub const COMPRESSION_PRESERVE_FRACTION: f32 = 0.3;

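/// Rough heuristic: about four characters per token for English text.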
const CHARS_PER_TOKEN: usize = 4;

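/// Default context window size assumed when none is configured.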
const DEFAULT_MAX_CONTEXT_TOKENS: usize = 128_000;

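/// A single user/assistant exchange, including any tool calls made while
/// producing the response.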
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationTurn {
    pub user_message: String,
    pub assistant_response: String,
    pub tool_calls: Vec<ToolCallRecord>,
    pub estimated_tokens: usize,
}

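/// Compact record of one tool invocation within a turn.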
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    pub tool_name: String,
    pub args_summary: String,
    pub result_summary: String,
}

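/// Conversation history that compresses older turns into a text summary once
/// the estimated token count crosses a configurable threshold.
///
/// Minimal usage sketch (illustrative only; the import path depends on where
/// this module lives in the crate):
///
/// ```ignore
/// let mut history = ConversationHistory::with_threshold(1_000, 0.1);
/// history.add_turn("Hi".to_string(), "Hello!".to_string(), vec![]);
/// if history.needs_compaction() {
///     let _summary = history.compact();
/// }
/// ```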
#[derive(Debug, Clone)]
pub struct ConversationHistory {
    /// Turns kept verbatim (most recent last).
    turns: Vec<ConversationTurn>,
    /// Summary of turns that have already been compacted away.
    compressed_summary: Option<String>,
    /// Estimated tokens of the summary (if any) plus all retained turns.
    total_tokens: usize,
    /// Token count above which `needs_compaction` returns true.
    compression_threshold_tokens: usize,
}

impl Default for ConversationHistory {
    fn default() -> Self {
        Self::new()
    }
}

impl ConversationHistory {
    /// Creates a history using the default context window and threshold.
    pub fn new() -> Self {
        Self::with_threshold(DEFAULT_MAX_CONTEXT_TOKENS, DEFAULT_COMPRESSION_THRESHOLD)
    }

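    /// Creates a history that compacts once estimated tokens exceed
    /// `max_context_tokens * threshold_fraction`.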
    pub fn with_threshold(max_context_tokens: usize, threshold_fraction: f32) -> Self {
        Self {
            turns: Vec::new(),
            compressed_summary: None,
            total_tokens: 0,
            compression_threshold_tokens: (max_context_tokens as f32 * threshold_fraction) as usize,
        }
    }

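    /// Rough token estimate based on the characters-per-token heuristic.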
    fn estimate_tokens(text: &str) -> usize {
        text.len() / CHARS_PER_TOKEN
    }

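    /// Records a completed turn and updates the running token estimate.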
    pub fn add_turn(
        &mut self,
        user_message: String,
        assistant_response: String,
        tool_calls: Vec<ToolCallRecord>,
    ) {
        let turn_tokens = Self::estimate_tokens(&user_message)
            + Self::estimate_tokens(&assistant_response)
            + tool_calls.iter().map(|tc| {
                Self::estimate_tokens(&tc.tool_name)
                    + Self::estimate_tokens(&tc.args_summary)
                    + Self::estimate_tokens(&tc.result_summary)
            }).sum::<usize>();

        self.turns.push(ConversationTurn {
            user_message,
            assistant_response,
            tool_calls,
            estimated_tokens: turn_tokens,
        });
        self.total_tokens += turn_tokens;
    }

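    /// True once the estimated token count exceeds the compaction threshold.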
    pub fn needs_compaction(&self) -> bool {
        self.total_tokens > self.compression_threshold_tokens
    }

    /// Estimated token count of the summary plus retained turns.
    pub fn token_count(&self) -> usize {
        self.total_tokens
    }

    /// Number of turns currently kept verbatim.
    pub fn turn_count(&self) -> usize {
        self.turns.len()
    }

    /// Drops all turns and any compressed summary.
    pub fn clear(&mut self) {
        self.turns.clear();
        self.compressed_summary = None;
        self.total_tokens = 0;
    }

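    /// Compresses the oldest turns into a text summary, keeping roughly the
    /// most recent `COMPRESSION_PRESERVE_FRACTION` of turns verbatim. Returns
    /// the newly created summary, or `None` if there is nothing to compact.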
    pub fn compact(&mut self) -> Option<String> {
        if self.turns.len() < 2 {
            return None;
        }

        let preserve_count = ((self.turns.len() as f32) * COMPRESSION_PRESERVE_FRACTION).ceil() as usize;
        let preserve_count = preserve_count.max(1);
        let split_point = self.turns.len().saturating_sub(preserve_count);

        if split_point == 0 {
            return None;
        }

        let turns_to_compress = &self.turns[..split_point];
        let summary = self.create_summary(turns_to_compress);

        let new_summary = if let Some(existing) = &self.compressed_summary {
            format!("{}\n\n{}", existing, summary)
        } else {
            summary.clone()
        };
        self.compressed_summary = Some(new_summary);

        let preserved_turns: Vec<_> = self.turns[split_point..].to_vec();
        self.turns = preserved_turns;

        self.total_tokens = Self::estimate_tokens(self.compressed_summary.as_deref().unwrap_or(""))
            + self.turns.iter().map(|t| t.estimated_tokens).sum::<usize>();

        Some(summary)
    }

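    /// Builds a plain-text summary of the given turns: the user question,
    /// the tools used, and a truncated response for each.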
    fn create_summary(&self, turns: &[ConversationTurn]) -> String {
        let mut summary_parts = Vec::new();

        for (i, turn) in turns.iter().enumerate() {
            let mut turn_summary = format!(
                "Turn {}: User asked about: {}",
                i + 1,
                Self::truncate_text(&turn.user_message, 100)
            );

            if !turn.tool_calls.is_empty() {
                let tool_names: Vec<_> = turn.tool_calls.iter()
                    .map(|tc| tc.tool_name.as_str())
                    .collect();
                turn_summary.push_str(&format!(". Tools used: {}", tool_names.join(", ")));
            }

            turn_summary.push_str(&format!(
                ". Response summary: {}",
                Self::truncate_text(&turn.assistant_response, 200)
            ));

            summary_parts.push(turn_summary);
        }

        format!(
            "Previous conversation summary ({} turns compressed):\n{}",
            turns.len(),
            summary_parts.join("\n")
        )
    }

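    /// Truncates `text` to at most `max_len` bytes, appending "..." when cut.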
    fn truncate_text(text: &str, max_len: usize) -> String {
        if text.len() <= max_len {
            text.to_string()
        } else {
            // Back up to a char boundary so slicing never panics on multi-byte UTF-8.
            let mut end = max_len.saturating_sub(3);
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &text[..end])
        }
    }

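    /// Converts the history into `rig` chat messages: the compressed summary
    /// (if any) is injected as a synthetic user/assistant exchange, followed by
    /// the retained turns in order.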
    pub fn to_messages(&self) -> Vec<Message> {
        use rig::completion::message::{Text, UserContent, AssistantContent};
        use rig::OneOrMany;

        let mut messages = Vec::new();

        if let Some(summary) = &self.compressed_summary {
            messages.push(Message::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: format!("[Previous conversation context]\n{}", summary),
                })),
            });
            messages.push(Message::Assistant {
                id: None,
                content: OneOrMany::one(AssistantContent::Text(Text {
                    text: "I understand the previous context. How can I help you continue?".to_string(),
                })),
            });
        }

        for turn in &self.turns {
            messages.push(Message::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: turn.user_message.clone(),
                })),
            });

            messages.push(Message::Assistant {
                id: None,
                content: OneOrMany::one(AssistantContent::Text(Text {
                    text: turn.assistant_response.clone(),
                })),
            });
        }

        messages
    }

    /// True when there are no retained turns and no compressed summary.
    pub fn is_empty(&self) -> bool {
        self.turns.is_empty() && self.compressed_summary.is_none()
    }

    /// One-line human-readable status, e.g. "3 turns, ~420 tokens".
    pub fn status(&self) -> String {
        let compressed_info = if self.compressed_summary.is_some() {
            " (with compressed history)"
        } else {
            ""
        };
        format!(
            "{} turns, ~{} tokens{}",
            self.turns.len(),
            self.total_tokens,
            compressed_info
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_turn() {
        let mut history = ConversationHistory::new();
        history.add_turn(
            "Hello".to_string(),
            "Hi there!".to_string(),
            vec![],
        );
        assert_eq!(history.turn_count(), 1);
        assert!(!history.is_empty());
    }

    #[test]
    fn test_compaction() {
        let mut history = ConversationHistory::with_threshold(1000, 0.1);
        for i in 0..10 {
            history.add_turn(
                format!("Question {}", i),
                format!("Answer {} with lots of detail to increase token count", i),
                vec![ToolCallRecord {
                    tool_name: "analyze".to_string(),
                    args_summary: "path: .".to_string(),
                    result_summary: "Found rust project".to_string(),
                }],
            );
        }

        if history.needs_compaction() {
            let summary = history.compact();
            assert!(summary.is_some());
            assert!(history.turn_count() < 10);
        }
    }

    #[test]
    fn test_to_messages() {
        let mut history = ConversationHistory::new();
        history.add_turn(
            "What is this project?".to_string(),
            "This is a Rust CLI tool.".to_string(),
            vec![],
        );

        let messages = history.to_messages();
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut history = ConversationHistory::new();
        history.add_turn("Test".to_string(), "Response".to_string(), vec![]);
        history.clear();
        assert!(history.is_empty());
        assert_eq!(history.token_count(), 0);
    }
}