1use crate::types::{ContentPart, Conversation, ConversationEntry, HistoryEntry, MessageRole};
2use chrono::{DateTime, Utc};
3
4pub struct ConversationQuery<'a> {
5 conversation: &'a Conversation,
6}
7
8impl<'a> ConversationQuery<'a> {
9 pub fn new(conversation: &'a Conversation) -> Self {
10 Self { conversation }
11 }
12
13 pub fn by_role(&self, role: MessageRole) -> Vec<&'a ConversationEntry> {
14 self.conversation
15 .entries
16 .iter()
17 .filter(|e| e.message.as_ref().map(|m| m.role == role).unwrap_or(false))
18 .collect()
19 }
20
21 pub fn by_type(&self, entry_type: &str) -> Vec<&'a ConversationEntry> {
22 self.conversation
23 .entries
24 .iter()
25 .filter(|e| e.entry_type == entry_type)
26 .collect()
27 }
28
29 pub fn by_time_range(
30 &self,
31 start: DateTime<Utc>,
32 end: DateTime<Utc>,
33 ) -> Vec<&'a ConversationEntry> {
34 self.conversation
35 .entries
36 .iter()
37 .filter(|e| {
38 if let Ok(timestamp) = e.timestamp.parse::<DateTime<Utc>>() {
39 timestamp >= start && timestamp <= end
40 } else {
41 false
42 }
43 })
44 .collect()
45 }
46
47 pub fn tool_uses_by_name(&self, tool_name: &str) -> Vec<&'a ConversationEntry> {
48 self.conversation
49 .entries
50 .iter()
51 .filter(|e| {
52 if let Some(message) = &e.message
53 && let Some(crate::types::MessageContent::Parts(parts)) = &message.content
54 {
55 return parts.iter().any(|p| {
56 if let ContentPart::ToolUse { name, .. } = p {
57 name == tool_name
58 } else {
59 false
60 }
61 });
62 }
63 false
64 })
65 .collect()
66 }
67
68 pub fn contains_text(&self, search: &str) -> Vec<&'a ConversationEntry> {
69 let search_lower = search.to_lowercase();
70 self.conversation
71 .entries
72 .iter()
73 .filter(|e| {
74 if let Some(message) = &e.message {
75 match &message.content {
76 Some(crate::types::MessageContent::Text(text)) => {
77 text.to_lowercase().contains(&search_lower)
78 }
79 Some(crate::types::MessageContent::Parts(parts)) => {
80 parts.iter().any(|p| match p {
81 ContentPart::Text { text } => {
82 text.to_lowercase().contains(&search_lower)
83 }
84 ContentPart::ToolResult { content, .. } => {
85 content.text().to_lowercase().contains(&search_lower)
86 }
87 _ => false,
88 })
89 }
90 None => false,
91 }
92 } else {
93 false
94 }
95 })
96 .collect()
97 }
98
99 pub fn errors(&self) -> Vec<&'a ConversationEntry> {
100 self.conversation
101 .entries
102 .iter()
103 .filter(|e| {
104 if let Some(message) = &e.message
105 && let Some(crate::types::MessageContent::Parts(parts)) = &message.content
106 {
107 return parts.iter().any(|p| {
108 if let ContentPart::ToolResult { is_error, .. } = p {
109 *is_error
110 } else {
111 false
112 }
113 });
114 }
115 false
116 })
117 .collect()
118 }
119}
120
121pub struct HistoryQuery<'a> {
122 history: &'a [HistoryEntry],
123}
124
125impl<'a> HistoryQuery<'a> {
126 pub fn new(history: &'a [HistoryEntry]) -> Self {
127 Self { history }
128 }
129
130 pub fn by_project(&self, project: &str) -> Vec<&'a HistoryEntry> {
131 self.history
132 .iter()
133 .filter(|e| e.project.as_deref() == Some(project))
134 .collect()
135 }
136
137 pub fn by_session(&self, session_id: &str) -> Vec<&'a HistoryEntry> {
138 self.history
139 .iter()
140 .filter(|e| e.session_id.as_deref() == Some(session_id))
141 .collect()
142 }
143
144 pub fn by_time_range(&self, start: i64, end: i64) -> Vec<&'a HistoryEntry> {
145 self.history
146 .iter()
147 .filter(|e| e.timestamp >= start && e.timestamp <= end)
148 .collect()
149 }
150
151 pub fn contains_text(&self, search: &str) -> Vec<&'a HistoryEntry> {
152 let search_lower = search.to_lowercase();
153 self.history
154 .iter()
155 .filter(|e| e.display.to_lowercase().contains(&search_lower))
156 .collect()
157 }
158
159 pub fn recent(&self, count: usize) -> Vec<&'a HistoryEntry> {
160 let mut sorted: Vec<&'a HistoryEntry> = self.history.iter().collect();
161 sorted.sort_by_key(|e| std::cmp::Reverse(e.timestamp));
162 sorted.into_iter().take(count).collect()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Conversation, ConversationEntry, Message, MessageContent};

    /// Builds a `Message` with the given role and content; all other metadata is `None`.
    fn message(role: MessageRole, content: MessageContent) -> Message {
        Message {
            role,
            content: Some(content),
            model: None,
            id: None,
            message_type: None,
            stop_reason: None,
            stop_sequence: None,
            usage: None,
        }
    }

    /// Builds a parentless, non-sidechain entry in session `"test"` carrying `message`.
    fn entry(uuid: &str, entry_type: &str, timestamp: &str, message: Message) -> ConversationEntry {
        ConversationEntry {
            parent_uuid: None,
            is_sidechain: false,
            entry_type: entry_type.to_string(),
            uuid: uuid.to_string(),
            timestamp: timestamp.to_string(),
            session_id: Some("test".to_string()),
            message: Some(message),
            cwd: None,
            git_branch: None,
            version: None,
            user_type: None,
            request_id: None,
            tool_use_result: None,
            snapshot: None,
            message_id: None,
            extra: Default::default(),
        }
    }

    /// Builds a conversation with id `"test"` containing `entries` in order.
    fn conversation_of(entries: impl IntoIterator<Item = ConversationEntry>) -> Conversation {
        let mut conv = Conversation::new("test".to_string());
        for e in entries {
            conv.add_entry(e);
        }
        conv
    }

    /// Parses a fixture timestamp; fixtures are hand-written, so failure is a test bug.
    fn dt(s: &str) -> DateTime<Utc> {
        s.parse().expect("fixture timestamp must be valid RFC 3339")
    }

    fn create_test_conversation() -> Conversation {
        let user_entry = entry(
            "1",
            "user",
            "2024-01-01T00:00:00Z",
            message(MessageRole::User, MessageContent::Text("Hello world".to_string())),
        );
        let mut assistant_entry = entry(
            "2",
            "assistant",
            "2024-01-01T00:00:01Z",
            message(MessageRole::Assistant, MessageContent::Text("Hi there".to_string())),
        );
        assistant_entry.parent_uuid = Some("1".to_string());
        conversation_of([user_entry, assistant_entry])
    }

    #[test]
    fn test_query_by_role() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.by_role(MessageRole::User).len(), 1);
        assert_eq!(query.by_role(MessageRole::Assistant).len(), 1);
    }

    #[test]
    fn test_query_contains_text() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let results = query.contains_text("Hello");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].uuid, "1");

        let results = query.contains_text("Hi");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].uuid, "2");
    }

    #[test]
    fn test_query_by_type() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let users = query.by_type("user");
        assert_eq!(users.len(), 1);
        assert_eq!(users[0].uuid, "1");

        let assistants = query.by_type("assistant");
        assert_eq!(assistants.len(), 1);
        assert_eq!(assistants[0].uuid, "2");
    }

    #[test]
    fn test_query_by_time_range() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let results =
            query.by_time_range(dt("2024-01-01T00:00:00Z"), dt("2024-01-01T00:00:00Z"));
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].uuid, "1");
    }

    #[test]
    fn test_query_by_time_range_all() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let results =
            query.by_time_range(dt("2023-01-01T00:00:00Z"), dt("2025-01-01T00:00:00Z"));
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_query_by_time_range_reversed_bounds_is_empty() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        let results =
            query.by_time_range(dt("2025-01-01T00:00:00Z"), dt("2023-01-01T00:00:00Z"));
        assert!(results.is_empty());
    }

    #[test]
    fn test_query_by_time_range_skips_unparseable_timestamps() {
        let conv = conversation_of([
            entry(
                "ok",
                "user",
                "2024-01-01T00:00:00Z",
                message(MessageRole::User, MessageContent::Text("a".to_string())),
            ),
            entry(
                "bad",
                "user",
                "not a timestamp",
                message(MessageRole::User, MessageContent::Text("b".to_string())),
            ),
        ]);
        let query = ConversationQuery::new(&conv);

        let results =
            query.by_time_range(dt("2000-01-01T00:00:00Z"), dt("2100-01-01T00:00:00Z"));
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].uuid, "ok");
    }

    #[test]
    fn test_query_tool_uses_by_name() {
        let conv = conversation_of([entry(
            "3",
            "assistant",
            "2024-01-01T00:00:02Z",
            message(
                MessageRole::Assistant,
                MessageContent::Parts(vec![ContentPart::ToolUse {
                    id: "t1".to_string(),
                    name: "Read".to_string(),
                    input: serde_json::Value::Object(Default::default()),
                }]),
            ),
        )]);
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.tool_uses_by_name("Read").len(), 1);
        assert!(query.tool_uses_by_name("Write").is_empty());
    }

    #[test]
    fn test_query_errors() {
        let conv = conversation_of([entry(
            "e1",
            "assistant",
            "2024-01-01T00:00:00Z",
            message(
                MessageRole::Assistant,
                MessageContent::Parts(vec![ContentPart::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: crate::types::ToolResultContent::Text("failed!".to_string()),
                    is_error: true,
                }]),
            ),
        )]);
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.errors().len(), 1);
    }

    #[test]
    fn test_query_errors_empty() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);
        assert!(query.errors().is_empty());
    }

    #[test]
    fn test_query_contains_text_case_insensitive() {
        let conv = create_test_conversation();
        let query = ConversationQuery::new(&conv);

        assert_eq!(query.contains_text("hello").len(), 1);
    }

    /// Builds a history record with no pasted contents.
    fn history_entry(display: &str, timestamp: i64, project: &str, session_id: &str) -> HistoryEntry {
        HistoryEntry {
            display: display.to_string(),
            pasted_contents: Default::default(),
            timestamp,
            project: Some(project.to_string()),
            session_id: Some(session_id.to_string()),
        }
    }

    fn create_test_history() -> Vec<HistoryEntry> {
        vec![
            history_entry("fix bug in auth", 1000, "/project/a", "session-1"),
            history_entry("add feature X", 2000, "/project/b", "session-2"),
            history_entry("refactor auth module", 3000, "/project/a", "session-1"),
        ]
    }

    #[test]
    fn test_history_by_project() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.by_project("/project/a").len(), 2);
    }

    #[test]
    fn test_history_by_session() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        let results = query.by_session("session-2");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].display, "add feature X");
    }

    #[test]
    fn test_history_by_time_range() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        let results = query.by_time_range(1500, 2500);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].timestamp, 2000);
    }

    #[test]
    fn test_history_contains_text() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.contains_text("auth").len(), 2);
    }

    #[test]
    fn test_history_contains_text_case_insensitive() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.contains_text("AUTH").len(), 2);
    }

    #[test]
    fn test_history_recent() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        let results = query.recent(2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].timestamp, 3000);
        assert_eq!(results[1].timestamp, 2000);
    }

    #[test]
    fn test_history_recent_more_than_available() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert_eq!(query.recent(10).len(), 3);
    }

    #[test]
    fn test_history_recent_zero_is_empty() {
        let history = create_test_history();
        let query = HistoryQuery::new(&history);

        assert!(query.recent(0).is_empty());
    }
}