1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10fn parse_role(s: &str) -> Role {
11 match s {
12 "assistant" => Role::Assistant,
13 "system" => Role::System,
14 _ => Role::User,
15 }
16}
17
18#[must_use]
19pub fn role_str(role: Role) -> &'static str {
20 match role {
21 Role::System => "system",
22 Role::User => "user",
23 Role::Assistant => "assistant",
24 }
25}
26
27impl SqliteStore {
28 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
34 let row: (ConversationId,) =
35 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
36 .fetch_one(&self.pool)
37 .await?;
38 Ok(row.0)
39 }
40
41 pub async fn save_message(
47 &self,
48 conversation_id: ConversationId,
49 role: &str,
50 content: &str,
51 ) -> Result<MessageId, MemoryError> {
52 self.save_message_with_parts(conversation_id, role, content, "[]")
53 .await
54 }
55
56 pub async fn save_message_with_parts(
62 &self,
63 conversation_id: ConversationId,
64 role: &str,
65 content: &str,
66 parts_json: &str,
67 ) -> Result<MessageId, MemoryError> {
68 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
69 .await
70 }
71
72 pub async fn save_message_with_metadata(
78 &self,
79 conversation_id: ConversationId,
80 role: &str,
81 content: &str,
82 parts_json: &str,
83 agent_visible: bool,
84 user_visible: bool,
85 ) -> Result<MessageId, MemoryError> {
86 let row: (MessageId,) = sqlx::query_as(
87 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
88 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
89 )
90 .bind(conversation_id)
91 .bind(role)
92 .bind(content)
93 .bind(parts_json)
94 .bind(i64::from(agent_visible))
95 .bind(i64::from(user_visible))
96 .fetch_one(&self.pool)
97 .await?;
98 Ok(row.0)
99 }
100
101 pub async fn load_history(
107 &self,
108 conversation_id: ConversationId,
109 limit: u32,
110 ) -> Result<Vec<Message>, MemoryError> {
111 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
112 "SELECT role, content, parts, agent_visible, user_visible FROM (\
113 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
114 WHERE conversation_id = ? \
115 ORDER BY id DESC \
116 LIMIT ?\
117 ) ORDER BY id ASC",
118 )
119 .bind(conversation_id)
120 .bind(limit)
121 .fetch_all(&self.pool)
122 .await?;
123
124 let messages = rows
125 .into_iter()
126 .map(
127 |(role_str, content, parts_json, agent_visible, user_visible)| {
128 let parts: Vec<MessagePart> =
129 serde_json::from_str(&parts_json).unwrap_or_default();
130 Message {
131 role: parse_role(&role_str),
132 content,
133 parts,
134 metadata: MessageMetadata {
135 agent_visible: agent_visible != 0,
136 user_visible: user_visible != 0,
137 compacted_at: None,
138 },
139 }
140 },
141 )
142 .collect();
143 Ok(messages)
144 }
145
146 pub async fn load_history_filtered(
154 &self,
155 conversation_id: ConversationId,
156 limit: u32,
157 agent_visible: Option<bool>,
158 user_visible: Option<bool>,
159 ) -> Result<Vec<Message>, MemoryError> {
160 let av = agent_visible.map(i64::from);
161 let uv = user_visible.map(i64::from);
162
163 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
164 "SELECT role, content, parts, agent_visible, user_visible FROM (\
165 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
166 WHERE conversation_id = ? \
167 AND (? IS NULL OR agent_visible = ?) \
168 AND (? IS NULL OR user_visible = ?) \
169 ORDER BY id DESC \
170 LIMIT ?\
171 ) ORDER BY id ASC",
172 )
173 .bind(conversation_id)
174 .bind(av)
175 .bind(av)
176 .bind(uv)
177 .bind(uv)
178 .bind(limit)
179 .fetch_all(&self.pool)
180 .await?;
181
182 let messages = rows
183 .into_iter()
184 .map(
185 |(role_str, content, parts_json, agent_visible, user_visible)| {
186 let parts: Vec<MessagePart> =
187 serde_json::from_str(&parts_json).unwrap_or_default();
188 Message {
189 role: parse_role(&role_str),
190 content,
191 parts,
192 metadata: MessageMetadata {
193 agent_visible: agent_visible != 0,
194 user_visible: user_visible != 0,
195 compacted_at: None,
196 },
197 }
198 },
199 )
200 .collect();
201 Ok(messages)
202 }
203
204 pub async fn replace_conversation(
216 &self,
217 conversation_id: ConversationId,
218 compacted_range: std::ops::RangeInclusive<MessageId>,
219 summary_role: &str,
220 summary_content: &str,
221 ) -> Result<MessageId, MemoryError> {
222 let now = {
223 let secs = std::time::SystemTime::now()
224 .duration_since(std::time::UNIX_EPOCH)
225 .unwrap_or_default()
226 .as_secs();
227 format!("{secs}")
228 };
229 let start_id = compacted_range.start().0;
230 let end_id = compacted_range.end().0;
231
232 let mut tx = self.pool.begin().await?;
233
234 sqlx::query(
235 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
236 WHERE conversation_id = ? AND id >= ? AND id <= ?",
237 )
238 .bind(&now)
239 .bind(conversation_id)
240 .bind(start_id)
241 .bind(end_id)
242 .execute(&mut *tx)
243 .await?;
244
245 let row: (MessageId,) = sqlx::query_as(
246 "INSERT INTO messages \
247 (conversation_id, role, content, parts, agent_visible, user_visible) \
248 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
249 )
250 .bind(conversation_id)
251 .bind(summary_role)
252 .bind(summary_content)
253 .fetch_one(&mut *tx)
254 .await?;
255
256 tx.commit().await?;
257
258 Ok(row.0)
259 }
260
261 pub async fn oldest_message_ids(
267 &self,
268 conversation_id: ConversationId,
269 n: u32,
270 ) -> Result<Vec<MessageId>, MemoryError> {
271 let rows: Vec<(MessageId,)> = sqlx::query_as(
272 "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
273 )
274 .bind(conversation_id)
275 .bind(n)
276 .fetch_all(&self.pool)
277 .await?;
278 Ok(rows.into_iter().map(|r| r.0).collect())
279 }
280
281 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
287 let row: Option<(ConversationId,)> =
288 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
289 .fetch_optional(&self.pool)
290 .await?;
291 Ok(row.map(|r| r.0))
292 }
293
294 pub async fn message_by_id(
300 &self,
301 message_id: MessageId,
302 ) -> Result<Option<Message>, MemoryError> {
303 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
304 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
305 )
306 .bind(message_id)
307 .fetch_optional(&self.pool)
308 .await?;
309
310 Ok(row.map(
311 |(role_str, content, parts_json, agent_visible, user_visible)| {
312 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
313 Message {
314 role: parse_role(&role_str),
315 content,
316 parts,
317 metadata: MessageMetadata {
318 agent_visible: agent_visible != 0,
319 user_visible: user_visible != 0,
320 compacted_at: None,
321 },
322 }
323 },
324 ))
325 }
326
327 pub async fn messages_by_ids(
333 &self,
334 ids: &[MessageId],
335 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
336 if ids.is_empty() {
337 return Ok(Vec::new());
338 }
339
340 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
341
342 let query = format!(
343 "SELECT id, role, content, parts FROM messages \
344 WHERE id IN ({placeholders}) AND agent_visible = 1"
345 );
346 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
347 for &id in ids {
348 q = q.bind(id);
349 }
350
351 let rows = q.fetch_all(&self.pool).await?;
352
353 Ok(rows
354 .into_iter()
355 .map(|(id, role_str, content, parts_json)| {
356 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
357 (
358 id,
359 Message {
360 role: parse_role(&role_str),
361 content,
362 parts,
363 metadata: MessageMetadata::default(),
364 },
365 )
366 })
367 .collect())
368 }
369
370 pub async fn unembedded_message_ids(
376 &self,
377 limit: Option<usize>,
378 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
379 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
380
381 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
382 "SELECT m.id, m.conversation_id, m.role, m.content \
383 FROM messages m \
384 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
385 WHERE em.id IS NULL \
386 ORDER BY m.id ASC \
387 LIMIT ?",
388 )
389 .bind(effective_limit)
390 .fetch_all(&self.pool)
391 .await?;
392
393 Ok(rows)
394 }
395
396 pub async fn count_messages(
402 &self,
403 conversation_id: ConversationId,
404 ) -> Result<i64, MemoryError> {
405 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
406 .bind(conversation_id)
407 .fetch_one(&self.pool)
408 .await?;
409 Ok(row.0)
410 }
411
412 pub async fn count_messages_after(
418 &self,
419 conversation_id: ConversationId,
420 after_id: MessageId,
421 ) -> Result<i64, MemoryError> {
422 let row: (i64,) =
423 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
424 .bind(conversation_id)
425 .bind(after_id)
426 .fetch_one(&self.pool)
427 .await?;
428 Ok(row.0)
429 }
430
431 pub async fn keyword_search(
440 &self,
441 query: &str,
442 limit: usize,
443 conversation_id: Option<ConversationId>,
444 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
445 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
446
447 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
448 sqlx::query_as(
449 "SELECT m.id, -rank AS score \
450 FROM messages_fts f \
451 JOIN messages m ON m.id = f.rowid \
452 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
453 ORDER BY rank \
454 LIMIT ?",
455 )
456 .bind(query)
457 .bind(cid)
458 .bind(effective_limit)
459 .fetch_all(&self.pool)
460 .await?
461 } else {
462 sqlx::query_as(
463 "SELECT m.id, -rank AS score \
464 FROM messages_fts f \
465 JOIN messages m ON m.id = f.rowid \
466 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
467 ORDER BY rank \
468 LIMIT ?",
469 )
470 .bind(query)
471 .bind(effective_limit)
472 .fetch_all(&self.pool)
473 .await?
474 };
475
476 Ok(rows)
477 }
478
479 pub async fn message_timestamps(
487 &self,
488 ids: &[MessageId],
489 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
490 if ids.is_empty() {
491 return Ok(std::collections::HashMap::new());
492 }
493
494 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
495 let query = format!(
496 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
497 FROM messages WHERE id IN ({placeholders})"
498 );
499 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
500 for &id in ids {
501 q = q.bind(id);
502 }
503
504 let rows = q.fetch_all(&self.pool).await?;
505 Ok(rows.into_iter().collect())
506 }
507
508 pub async fn load_messages_range(
514 &self,
515 conversation_id: ConversationId,
516 after_message_id: MessageId,
517 limit: usize,
518 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
519 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
520
521 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
522 "SELECT id, role, content FROM messages \
523 WHERE conversation_id = ? AND id > ? \
524 ORDER BY id ASC LIMIT ?",
525 )
526 .bind(conversation_id)
527 .bind(after_message_id)
528 .bind(effective_limit)
529 .fetch_all(&self.pool)
530 .await?;
531
532 Ok(rows)
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory store for each test.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[test]
    fn role_str_round_trips_through_parse_role() {
        assert_eq!(parse_role(role_str(Role::System)), Role::System);
        assert_eq!(parse_role(role_str(Role::User)), Role::User);
        assert_eq!(parse_role(role_str(Role::Assistant)), Role::Assistant);
    }

    #[test]
    fn parse_role_falls_back_to_user() {
        assert_eq!(parse_role("tool"), Role::User);
        assert_eq!(parse_role(""), Role::User);
    }

    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        // The newest three messages are returned, oldest first.
        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[0].1.content, "hello");
        assert_eq!(results[1].0, id2);
        assert_eq!(results[1].1.content, "hi");
    }

    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_skips_agent_hidden_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let visible = store.save_message(cid, "user", "visible").await.unwrap();
        let hidden = store
            .save_message_with_metadata(cid, "user", "hidden", "[]", false, true)
            .await
            .unwrap();

        let results = store.messages_by_ids(&[visible, hidden]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, visible);
        assert_eq!(results[0].1.content, "visible");
    }

    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store
            .message_by_id(msg_id)
            .await
            .unwrap()
            .expect("freshly saved message should be found");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn message_timestamps_covers_existing_ids_only() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();

        let timestamps = store
            .message_timestamps(&[id1, id2, MessageId(999)])
            .await
            .unwrap();
        assert_eq!(timestamps.len(), 2);
        assert!(timestamps.contains_key(&id1));
        assert!(timestamps.contains_key(&id2));
        assert!(!timestamps.contains_key(&MessageId(999)));
    }

    #[tokio::test]
    async fn message_timestamps_empty_input() {
        let store = test_store().await;
        let timestamps = store.message_timestamps(&[]).await.unwrap();
        assert!(timestamps.is_empty());
    }

    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);

        let results = store.keyword_search("hello", 10, Some(cid2)).await.unwrap();
        assert_eq!(results.len(), 1);

        let results = store.keyword_search("hello", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn keyword_search_excludes_agent_hidden_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "hidden rust note", "[]", false, true)
            .await
            .unwrap();
        let visible = store
            .save_message(cid, "user", "visible rust note")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, visible);
    }

    #[tokio::test]
    async fn save_message_with_metadata_stores_visibility() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(!history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert_eq!(id, MessageId(1));
    }

    #[tokio::test]
    async fn load_history_filtered_by_agent_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let agent_msgs = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        assert_eq!(agent_msgs.len(), 1);
        assert_eq!(agent_msgs[0].content, "visible to agent");
    }

    #[tokio::test]
    async fn load_history_filtered_by_user_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
            .await
            .unwrap();

        let user_msgs = store
            .load_history_filtered(cid, 50, None, Some(true))
            .await
            .unwrap();
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0].content, "user sees this");
    }

    #[tokio::test]
    async fn load_history_filtered_no_filter_returns_all() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
            .await
            .unwrap();

        let all_msgs = store
            .load_history_filtered(cid, 50, None, None)
            .await
            .unwrap();
        assert_eq!(all_msgs.len(), 2);
    }

    #[tokio::test]
    async fn load_history_filtered_combines_both_filters() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "both", "[]", true, true)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "system", "agent only", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let both = store
            .load_history_filtered(cid, 50, Some(true), Some(true))
            .await
            .unwrap();
        assert_eq!(both.len(), 1);
        assert_eq!(both[0].content, "both");

        let hidden_from_agent = store
            .load_history_filtered(cid, 50, Some(false), None)
            .await
            .unwrap();
        assert_eq!(hidden_from_agent.len(), 1);
        assert_eq!(hidden_from_agent[0].content, "user only");
    }

    #[tokio::test]
    async fn replace_conversation_marks_originals_and_inserts_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "first").await.unwrap();
        let id2 = store
            .save_message(cid, "assistant", "second")
            .await
            .unwrap();
        let id3 = store.save_message(cid, "user", "third").await.unwrap();

        let summary_id = store
            .replace_conversation(cid, id1..=id2, "system", "summary text")
            .await
            .unwrap();

        let all = store.load_history(cid, 50).await.unwrap();
        // Compacted originals stay visible to the user but are hidden from the agent.
        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
        assert!(!by_id1.metadata.agent_visible);
        assert!(by_id1.metadata.user_visible);

        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
        assert!(!by_id2.metadata.agent_visible);

        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
        assert!(by_id3.metadata.agent_visible);

        // The summary is agent-only and is appended after every existing message.
        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
        assert!(summary.metadata.agent_visible);
        assert!(!summary.metadata.user_visible);
        assert!(summary_id > id3);

        let agent_view = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        let agent_contents: Vec<&str> = agent_view.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(agent_contents, vec!["third", "summary text"]);
    }

    #[tokio::test]
    async fn oldest_message_ids_returns_in_order() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
        let id3 = store.save_message(cid, "user", "c").await.unwrap();

        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
        assert_eq!(ids, vec![id1, id2]);
        assert!(ids[0] < ids[1]);

        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
        assert_eq!(all_ids, vec![id1, id2, id3]);
    }

    #[tokio::test]
    async fn message_metadata_default_both_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "normal").await.unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert!(history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert!(history[0].metadata.compacted_at.is_none());
    }
}
1081}