1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
/// Converts free-form text into a string that is safe to use in an FTS5 `MATCH` clause.
///
/// The input is split on every non-alphanumeric character, empty fragments are
/// dropped, and the surviving tokens are joined with single spaces. This removes
/// FTS5 punctuation such as `"`, `*`, `(`, `)`, `^`, `-`, `+` and `:`.
///
/// Tokens that are exactly `AND`, `OR`, `NOT` or `NEAR` are lower-cased. FTS5 only
/// recognises these keywords in upper case, so lower-casing turns them into ordinary
/// search terms; left untouched, input such as `"rust AND"` would produce a malformed
/// `MATCH` expression and make the query fail.
///
/// Returns an empty string when the input contains no alphanumeric characters.
pub(crate) fn sanitize_fts5_query(query: &str) -> String {
    let mut sanitized = String::with_capacity(query.len());
    let tokens = query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty());

    for token in tokens {
        if !sanitized.is_empty() {
            sanitized.push(' ');
        }
        if matches!(token, "AND" | "OR" | "NOT" | "NEAR") {
            // Neutralise the keyword: lower-case barewords are plain terms to FTS5.
            sanitized.push_str(&token.to_ascii_lowercase());
        } else {
            sanitized.push_str(token);
        }
    }
    sanitized
}
26
27fn parse_role(s: &str) -> Role {
28 match s {
29 "assistant" => Role::Assistant,
30 "system" => Role::System,
31 _ => Role::User,
32 }
33}
34
35#[must_use]
36pub fn role_str(role: Role) -> &'static str {
37 match role {
38 Role::System => "system",
39 Role::User => "user",
40 Role::Assistant => "assistant",
41 }
42}
43
44fn parse_parts_json(role_str: &str, parts_json: &str) -> Vec<MessagePart> {
49 if parts_json == "[]" {
50 return vec![];
51 }
52 match serde_json::from_str(parts_json) {
53 Ok(p) => p,
54 Err(e) => {
55 let truncated = parts_json.chars().take(120).collect::<String>();
56 tracing::warn!(
57 role = %role_str,
58 parts_json = %truncated,
59 error = %e,
60 "failed to deserialize message parts, falling back to empty"
61 );
62 vec![]
63 }
64 }
65}
66
67impl SqliteStore {
68 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
74 let row: (ConversationId,) =
75 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
76 .fetch_one(&self.pool)
77 .await?;
78 Ok(row.0)
79 }
80
81 pub async fn save_message(
87 &self,
88 conversation_id: ConversationId,
89 role: &str,
90 content: &str,
91 ) -> Result<MessageId, MemoryError> {
92 self.save_message_with_parts(conversation_id, role, content, "[]")
93 .await
94 }
95
96 pub async fn save_message_with_parts(
102 &self,
103 conversation_id: ConversationId,
104 role: &str,
105 content: &str,
106 parts_json: &str,
107 ) -> Result<MessageId, MemoryError> {
108 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
109 .await
110 }
111
112 pub async fn save_message_with_metadata(
118 &self,
119 conversation_id: ConversationId,
120 role: &str,
121 content: &str,
122 parts_json: &str,
123 agent_visible: bool,
124 user_visible: bool,
125 ) -> Result<MessageId, MemoryError> {
126 let row: (MessageId,) = sqlx::query_as(
127 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
128 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
129 )
130 .bind(conversation_id)
131 .bind(role)
132 .bind(content)
133 .bind(parts_json)
134 .bind(i64::from(agent_visible))
135 .bind(i64::from(user_visible))
136 .fetch_one(&self.pool)
137 .await?;
138 Ok(row.0)
139 }
140
141 pub async fn load_history(
147 &self,
148 conversation_id: ConversationId,
149 limit: u32,
150 ) -> Result<Vec<Message>, MemoryError> {
151 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
152 "SELECT role, content, parts, agent_visible, user_visible FROM (\
153 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
154 WHERE conversation_id = ? AND deleted_at IS NULL \
155 ORDER BY id DESC \
156 LIMIT ?\
157 ) ORDER BY id ASC",
158 )
159 .bind(conversation_id)
160 .bind(limit)
161 .fetch_all(&self.pool)
162 .await?;
163
164 let messages = rows
165 .into_iter()
166 .map(
167 |(role_str, content, parts_json, agent_visible, user_visible)| {
168 let parts = parse_parts_json(&role_str, &parts_json);
169 Message {
170 role: parse_role(&role_str),
171 content,
172 parts,
173 metadata: MessageMetadata {
174 agent_visible: agent_visible != 0,
175 user_visible: user_visible != 0,
176 compacted_at: None,
177 deferred_summary: None,
178 focus_pinned: false,
179 focus_marker_id: None,
180 },
181 }
182 },
183 )
184 .collect();
185 Ok(messages)
186 }
187
188 pub async fn load_history_filtered(
196 &self,
197 conversation_id: ConversationId,
198 limit: u32,
199 agent_visible: Option<bool>,
200 user_visible: Option<bool>,
201 ) -> Result<Vec<Message>, MemoryError> {
202 let av = agent_visible.map(i64::from);
203 let uv = user_visible.map(i64::from);
204
205 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
206 "WITH recent AS (\
207 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
208 WHERE conversation_id = ? \
209 AND deleted_at IS NULL \
210 AND (? IS NULL OR agent_visible = ?) \
211 AND (? IS NULL OR user_visible = ?) \
212 ORDER BY id DESC \
213 LIMIT ?\
214 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
215 )
216 .bind(conversation_id)
217 .bind(av)
218 .bind(av)
219 .bind(uv)
220 .bind(uv)
221 .bind(limit)
222 .fetch_all(&self.pool)
223 .await?;
224
225 let messages = rows
226 .into_iter()
227 .map(
228 |(role_str, content, parts_json, agent_visible, user_visible)| {
229 let parts = parse_parts_json(&role_str, &parts_json);
230 Message {
231 role: parse_role(&role_str),
232 content,
233 parts,
234 metadata: MessageMetadata {
235 agent_visible: agent_visible != 0,
236 user_visible: user_visible != 0,
237 compacted_at: None,
238 deferred_summary: None,
239 focus_pinned: false,
240 focus_marker_id: None,
241 },
242 }
243 },
244 )
245 .collect();
246 Ok(messages)
247 }
248
249 pub async fn replace_conversation(
261 &self,
262 conversation_id: ConversationId,
263 compacted_range: std::ops::RangeInclusive<MessageId>,
264 summary_role: &str,
265 summary_content: &str,
266 ) -> Result<MessageId, MemoryError> {
267 let now = {
268 let secs = std::time::SystemTime::now()
269 .duration_since(std::time::UNIX_EPOCH)
270 .unwrap_or_default()
271 .as_secs();
272 format!("{secs}")
273 };
274 let start_id = compacted_range.start().0;
275 let end_id = compacted_range.end().0;
276
277 let mut tx = self.pool.begin().await?;
278
279 sqlx::query(
280 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
281 WHERE conversation_id = ? AND id >= ? AND id <= ?",
282 )
283 .bind(&now)
284 .bind(conversation_id)
285 .bind(start_id)
286 .bind(end_id)
287 .execute(&mut *tx)
288 .await?;
289
290 let row: (MessageId,) = sqlx::query_as(
291 "INSERT INTO messages \
292 (conversation_id, role, content, parts, agent_visible, user_visible) \
293 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
294 )
295 .bind(conversation_id)
296 .bind(summary_role)
297 .bind(summary_content)
298 .fetch_one(&mut *tx)
299 .await?;
300
301 tx.commit().await?;
302
303 Ok(row.0)
304 }
305
306 pub async fn oldest_message_ids(
312 &self,
313 conversation_id: ConversationId,
314 n: u32,
315 ) -> Result<Vec<MessageId>, MemoryError> {
316 let rows: Vec<(MessageId,)> = sqlx::query_as(
317 "SELECT id FROM messages WHERE conversation_id = ? AND deleted_at IS NULL ORDER BY id ASC LIMIT ?",
318 )
319 .bind(conversation_id)
320 .bind(n)
321 .fetch_all(&self.pool)
322 .await?;
323 Ok(rows.into_iter().map(|r| r.0).collect())
324 }
325
326 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
332 let row: Option<(ConversationId,)> =
333 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
334 .fetch_optional(&self.pool)
335 .await?;
336 Ok(row.map(|r| r.0))
337 }
338
339 pub async fn message_by_id(
345 &self,
346 message_id: MessageId,
347 ) -> Result<Option<Message>, MemoryError> {
348 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
349 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ? AND deleted_at IS NULL",
350 )
351 .bind(message_id)
352 .fetch_optional(&self.pool)
353 .await?;
354
355 Ok(row.map(
356 |(role_str, content, parts_json, agent_visible, user_visible)| {
357 let parts = parse_parts_json(&role_str, &parts_json);
358 Message {
359 role: parse_role(&role_str),
360 content,
361 parts,
362 metadata: MessageMetadata {
363 agent_visible: agent_visible != 0,
364 user_visible: user_visible != 0,
365 compacted_at: None,
366 deferred_summary: None,
367 focus_pinned: false,
368 focus_marker_id: None,
369 },
370 }
371 },
372 ))
373 }
374
375 pub async fn messages_by_ids(
381 &self,
382 ids: &[MessageId],
383 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
384 if ids.is_empty() {
385 return Ok(Vec::new());
386 }
387
388 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
389
390 let query = format!(
391 "SELECT id, role, content, parts FROM messages \
392 WHERE id IN ({placeholders}) AND agent_visible = 1 AND deleted_at IS NULL"
393 );
394 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
395 for &id in ids {
396 q = q.bind(id);
397 }
398
399 let rows = q.fetch_all(&self.pool).await?;
400
401 Ok(rows
402 .into_iter()
403 .map(|(id, role_str, content, parts_json)| {
404 let parts = parse_parts_json(&role_str, &parts_json);
405 (
406 id,
407 Message {
408 role: parse_role(&role_str),
409 content,
410 parts,
411 metadata: MessageMetadata::default(),
412 },
413 )
414 })
415 .collect())
416 }
417
418 pub async fn unembedded_message_ids(
424 &self,
425 limit: Option<usize>,
426 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
427 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
428
429 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
430 "SELECT m.id, m.conversation_id, m.role, m.content \
431 FROM messages m \
432 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
433 WHERE em.id IS NULL AND m.deleted_at IS NULL \
434 ORDER BY m.id ASC \
435 LIMIT ?",
436 )
437 .bind(effective_limit)
438 .fetch_all(&self.pool)
439 .await?;
440
441 Ok(rows)
442 }
443
444 pub async fn count_messages(
450 &self,
451 conversation_id: ConversationId,
452 ) -> Result<i64, MemoryError> {
453 let row: (i64,) = sqlx::query_as(
454 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND deleted_at IS NULL",
455 )
456 .bind(conversation_id)
457 .fetch_one(&self.pool)
458 .await?;
459 Ok(row.0)
460 }
461
462 pub async fn count_messages_after(
468 &self,
469 conversation_id: ConversationId,
470 after_id: MessageId,
471 ) -> Result<i64, MemoryError> {
472 let row: (i64,) =
473 sqlx::query_as(
474 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL",
475 )
476 .bind(conversation_id)
477 .bind(after_id)
478 .fetch_one(&self.pool)
479 .await?;
480 Ok(row.0)
481 }
482
483 pub async fn keyword_search(
492 &self,
493 query: &str,
494 limit: usize,
495 conversation_id: Option<ConversationId>,
496 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
497 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
498 let safe_query = sanitize_fts5_query(query);
499 if safe_query.is_empty() {
500 return Ok(Vec::new());
501 }
502
503 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
504 sqlx::query_as(
505 "SELECT m.id, -rank AS score \
506 FROM messages_fts f \
507 JOIN messages m ON m.id = f.rowid \
508 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
509 ORDER BY rank \
510 LIMIT ?",
511 )
512 .bind(&safe_query)
513 .bind(cid)
514 .bind(effective_limit)
515 .fetch_all(&self.pool)
516 .await?
517 } else {
518 sqlx::query_as(
519 "SELECT m.id, -rank AS score \
520 FROM messages_fts f \
521 JOIN messages m ON m.id = f.rowid \
522 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
523 ORDER BY rank \
524 LIMIT ?",
525 )
526 .bind(&safe_query)
527 .bind(effective_limit)
528 .fetch_all(&self.pool)
529 .await?
530 };
531
532 Ok(rows)
533 }
534
535 pub async fn keyword_search_with_time_range(
548 &self,
549 query: &str,
550 limit: usize,
551 conversation_id: Option<ConversationId>,
552 after: Option<&str>,
553 before: Option<&str>,
554 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
555 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
556 let safe_query = sanitize_fts5_query(query);
557 if safe_query.is_empty() {
558 return Ok(Vec::new());
559 }
560
561 let after_clause = if after.is_some() {
563 " AND m.created_at > ?"
564 } else {
565 ""
566 };
567 let before_clause = if before.is_some() {
568 " AND m.created_at < ?"
569 } else {
570 ""
571 };
572 let conv_clause = if conversation_id.is_some() {
573 " AND m.conversation_id = ?"
574 } else {
575 ""
576 };
577
578 let sql = format!(
579 "SELECT m.id, -rank AS score \
580 FROM messages_fts f \
581 JOIN messages m ON m.id = f.rowid \
582 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL\
583 {after_clause}{before_clause}{conv_clause} \
584 ORDER BY rank \
585 LIMIT ?"
586 );
587
588 let mut q = sqlx::query_as::<_, (MessageId, f64)>(&sql).bind(&safe_query);
589 if let Some(a) = after {
590 q = q.bind(a);
591 }
592 if let Some(b) = before {
593 q = q.bind(b);
594 }
595 if let Some(cid) = conversation_id {
596 q = q.bind(cid);
597 }
598 q = q.bind(effective_limit);
599
600 Ok(q.fetch_all(&self.pool).await?)
601 }
602
603 pub async fn message_timestamps(
611 &self,
612 ids: &[MessageId],
613 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
614 if ids.is_empty() {
615 return Ok(std::collections::HashMap::new());
616 }
617
618 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
619 let query = format!(
620 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
621 FROM messages WHERE id IN ({placeholders}) AND deleted_at IS NULL"
622 );
623 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
624 for &id in ids {
625 q = q.bind(id);
626 }
627
628 let rows = q.fetch_all(&self.pool).await?;
629 Ok(rows.into_iter().collect())
630 }
631
632 pub async fn load_messages_range(
638 &self,
639 conversation_id: ConversationId,
640 after_message_id: MessageId,
641 limit: usize,
642 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
643 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
644
645 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
646 "SELECT id, role, content FROM messages \
647 WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL \
648 ORDER BY id ASC LIMIT ?",
649 )
650 .bind(conversation_id)
651 .bind(after_message_id)
652 .bind(effective_limit)
653 .fetch_all(&self.pool)
654 .await?;
655
656 Ok(rows)
657 }
658
659 pub async fn get_eviction_candidates(
667 &self,
668 ) -> Result<Vec<crate::eviction::EvictionEntry>, crate::error::MemoryError> {
669 let rows: Vec<(MessageId, String, Option<String>, i64)> = sqlx::query_as(
670 "SELECT id, created_at, last_accessed, access_count \
671 FROM messages WHERE deleted_at IS NULL",
672 )
673 .fetch_all(&self.pool)
674 .await?;
675
676 Ok(rows
677 .into_iter()
678 .map(
679 |(id, created_at, last_accessed, access_count)| crate::eviction::EvictionEntry {
680 id,
681 created_at,
682 last_accessed,
683 access_count: access_count.try_into().unwrap_or(0),
684 },
685 )
686 .collect())
687 }
688
689 pub async fn soft_delete_messages(
697 &self,
698 ids: &[MessageId],
699 ) -> Result<(), crate::error::MemoryError> {
700 if ids.is_empty() {
701 return Ok(());
702 }
703 for &id in ids {
705 sqlx::query(
706 "UPDATE messages SET deleted_at = datetime('now') WHERE id = ? AND deleted_at IS NULL",
707 )
708 .bind(id)
709 .execute(&self.pool)
710 .await?;
711 }
712 Ok(())
713 }
714
715 pub async fn get_soft_deleted_message_ids(
721 &self,
722 ) -> Result<Vec<MessageId>, crate::error::MemoryError> {
723 let rows: Vec<(MessageId,)> = sqlx::query_as(
724 "SELECT id FROM messages WHERE deleted_at IS NOT NULL AND qdrant_cleaned = 0",
725 )
726 .fetch_all(&self.pool)
727 .await?;
728 Ok(rows.into_iter().map(|(id,)| id).collect())
729 }
730
731 pub async fn mark_qdrant_cleaned(
737 &self,
738 ids: &[MessageId],
739 ) -> Result<(), crate::error::MemoryError> {
740 for &id in ids {
741 sqlx::query("UPDATE messages SET qdrant_cleaned = 1 WHERE id = ?")
742 .bind(id)
743 .execute(&self.pool)
744 .await?;
745 }
746 Ok(())
747 }
748}
749
750#[cfg(test)]
751mod tests {
752 use super::*;
753
754 async fn test_store() -> SqliteStore {
755 SqliteStore::new(":memory:").await.unwrap()
756 }
757
758 #[tokio::test]
759 async fn create_conversation_returns_id() {
760 let store = test_store().await;
761 let id1 = store.create_conversation().await.unwrap();
762 let id2 = store.create_conversation().await.unwrap();
763 assert_eq!(id1, ConversationId(1));
764 assert_eq!(id2, ConversationId(2));
765 }
766
767 #[tokio::test]
768 async fn save_and_load_messages() {
769 let store = test_store().await;
770 let cid = store.create_conversation().await.unwrap();
771
772 let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
773 let msg_id2 = store
774 .save_message(cid, "assistant", "hi there")
775 .await
776 .unwrap();
777
778 assert_eq!(msg_id1, MessageId(1));
779 assert_eq!(msg_id2, MessageId(2));
780
781 let history = store.load_history(cid, 50).await.unwrap();
782 assert_eq!(history.len(), 2);
783 assert_eq!(history[0].role, Role::User);
784 assert_eq!(history[0].content, "hello");
785 assert_eq!(history[1].role, Role::Assistant);
786 assert_eq!(history[1].content, "hi there");
787 }
788
789 #[tokio::test]
790 async fn load_history_respects_limit() {
791 let store = test_store().await;
792 let cid = store.create_conversation().await.unwrap();
793
794 for i in 0..10 {
795 store
796 .save_message(cid, "user", &format!("msg {i}"))
797 .await
798 .unwrap();
799 }
800
801 let history = store.load_history(cid, 3).await.unwrap();
802 assert_eq!(history.len(), 3);
803 assert_eq!(history[0].content, "msg 7");
804 assert_eq!(history[1].content, "msg 8");
805 assert_eq!(history[2].content, "msg 9");
806 }
807
808 #[tokio::test]
809 async fn latest_conversation_id_empty() {
810 let store = test_store().await;
811 assert!(store.latest_conversation_id().await.unwrap().is_none());
812 }
813
814 #[tokio::test]
815 async fn latest_conversation_id_returns_newest() {
816 let store = test_store().await;
817 store.create_conversation().await.unwrap();
818 let id2 = store.create_conversation().await.unwrap();
819 assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
820 }
821
822 #[tokio::test]
823 async fn messages_isolated_per_conversation() {
824 let store = test_store().await;
825 let cid1 = store.create_conversation().await.unwrap();
826 let cid2 = store.create_conversation().await.unwrap();
827
828 store.save_message(cid1, "user", "conv1").await.unwrap();
829 store.save_message(cid2, "user", "conv2").await.unwrap();
830
831 let h1 = store.load_history(cid1, 50).await.unwrap();
832 let h2 = store.load_history(cid2, 50).await.unwrap();
833 assert_eq!(h1.len(), 1);
834 assert_eq!(h1[0].content, "conv1");
835 assert_eq!(h2.len(), 1);
836 assert_eq!(h2[0].content, "conv2");
837 }
838
839 #[tokio::test]
840 async fn pool_accessor_returns_valid_pool() {
841 let store = test_store().await;
842 let pool = store.pool();
843 let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
844 assert_eq!(row.0, 1);
845 }
846
847 #[tokio::test]
848 async fn embeddings_metadata_table_exists() {
849 let store = test_store().await;
850 let result: (i64,) = sqlx::query_as(
851 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
852 )
853 .fetch_one(store.pool())
854 .await
855 .unwrap();
856 assert_eq!(result.0, 1);
857 }
858
859 #[tokio::test]
860 async fn cascade_delete_removes_embeddings_metadata() {
861 let store = test_store().await;
862 let pool = store.pool();
863
864 let cid = store.create_conversation().await.unwrap();
865 let msg_id = store.save_message(cid, "user", "test").await.unwrap();
866
867 let point_id = uuid::Uuid::new_v4().to_string();
868 sqlx::query(
869 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
870 VALUES (?, ?, ?)",
871 )
872 .bind(msg_id)
873 .bind(&point_id)
874 .bind(768_i64)
875 .execute(pool)
876 .await
877 .unwrap();
878
879 let before: (i64,) =
880 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
881 .bind(msg_id)
882 .fetch_one(pool)
883 .await
884 .unwrap();
885 assert_eq!(before.0, 1);
886
887 sqlx::query("DELETE FROM messages WHERE id = ?")
888 .bind(msg_id)
889 .execute(pool)
890 .await
891 .unwrap();
892
893 let after: (i64,) =
894 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
895 .bind(msg_id)
896 .fetch_one(pool)
897 .await
898 .unwrap();
899 assert_eq!(after.0, 0);
900 }
901
902 #[tokio::test]
903 async fn messages_by_ids_batch_fetch() {
904 let store = test_store().await;
905 let cid = store.create_conversation().await.unwrap();
906 let id1 = store.save_message(cid, "user", "hello").await.unwrap();
907 let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
908 let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
909
910 let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
911 assert_eq!(results.len(), 2);
912 assert_eq!(results[0].0, id1);
913 assert_eq!(results[0].1.content, "hello");
914 assert_eq!(results[1].0, id2);
915 assert_eq!(results[1].1.content, "hi");
916 }
917
918 #[tokio::test]
919 async fn messages_by_ids_empty_input() {
920 let store = test_store().await;
921 let results = store.messages_by_ids(&[]).await.unwrap();
922 assert!(results.is_empty());
923 }
924
925 #[tokio::test]
926 async fn messages_by_ids_nonexistent() {
927 let store = test_store().await;
928 let results = store
929 .messages_by_ids(&[MessageId(999), MessageId(1000)])
930 .await
931 .unwrap();
932 assert!(results.is_empty());
933 }
934
935 #[tokio::test]
936 async fn message_by_id_fetches_existing() {
937 let store = test_store().await;
938 let cid = store.create_conversation().await.unwrap();
939 let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
940
941 let msg = store.message_by_id(msg_id).await.unwrap();
942 assert!(msg.is_some());
943 let msg = msg.unwrap();
944 assert_eq!(msg.role, Role::User);
945 assert_eq!(msg.content, "hello");
946 }
947
948 #[tokio::test]
949 async fn message_by_id_returns_none_for_nonexistent() {
950 let store = test_store().await;
951 let msg = store.message_by_id(MessageId(999)).await.unwrap();
952 assert!(msg.is_none());
953 }
954
955 #[tokio::test]
956 async fn unembedded_message_ids_returns_all_when_none_embedded() {
957 let store = test_store().await;
958 let cid = store.create_conversation().await.unwrap();
959
960 store.save_message(cid, "user", "msg1").await.unwrap();
961 store.save_message(cid, "assistant", "msg2").await.unwrap();
962
963 let unembedded = store.unembedded_message_ids(None).await.unwrap();
964 assert_eq!(unembedded.len(), 2);
965 assert_eq!(unembedded[0].3, "msg1");
966 assert_eq!(unembedded[1].3, "msg2");
967 }
968
969 #[tokio::test]
970 async fn unembedded_message_ids_excludes_embedded() {
971 let store = test_store().await;
972 let pool = store.pool();
973 let cid = store.create_conversation().await.unwrap();
974
975 let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
976 let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
977
978 let point_id = uuid::Uuid::new_v4().to_string();
979 sqlx::query(
980 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
981 VALUES (?, ?, ?)",
982 )
983 .bind(msg_id1)
984 .bind(&point_id)
985 .bind(768_i64)
986 .execute(pool)
987 .await
988 .unwrap();
989
990 let unembedded = store.unembedded_message_ids(None).await.unwrap();
991 assert_eq!(unembedded.len(), 1);
992 assert_eq!(unembedded[0].0, msg_id2);
993 assert_eq!(unembedded[0].3, "msg2");
994 }
995
996 #[tokio::test]
997 async fn unembedded_message_ids_respects_limit() {
998 let store = test_store().await;
999 let cid = store.create_conversation().await.unwrap();
1000
1001 for i in 0..10 {
1002 store
1003 .save_message(cid, "user", &format!("msg{i}"))
1004 .await
1005 .unwrap();
1006 }
1007
1008 let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
1009 assert_eq!(unembedded.len(), 3);
1010 }
1011
1012 #[tokio::test]
1013 async fn count_messages_returns_correct_count() {
1014 let store = test_store().await;
1015 let cid = store.create_conversation().await.unwrap();
1016
1017 assert_eq!(store.count_messages(cid).await.unwrap(), 0);
1018
1019 store.save_message(cid, "user", "msg1").await.unwrap();
1020 store.save_message(cid, "assistant", "msg2").await.unwrap();
1021
1022 assert_eq!(store.count_messages(cid).await.unwrap(), 2);
1023 }
1024
1025 #[tokio::test]
1026 async fn count_messages_after_filters_correctly() {
1027 let store = test_store().await;
1028 let cid = store.create_conversation().await.unwrap();
1029
1030 let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
1031 let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
1032 let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
1033
1034 assert_eq!(
1035 store.count_messages_after(cid, MessageId(0)).await.unwrap(),
1036 3
1037 );
1038 assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
1039 assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
1040 }
1041
1042 #[tokio::test]
1043 async fn load_messages_range_basic() {
1044 let store = test_store().await;
1045 let cid = store.create_conversation().await.unwrap();
1046
1047 let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
1048 let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
1049 let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
1050
1051 let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
1052 assert_eq!(msgs.len(), 2);
1053 assert_eq!(msgs[0].0, msg_id2);
1054 assert_eq!(msgs[0].2, "msg2");
1055 assert_eq!(msgs[1].0, msg_id3);
1056 assert_eq!(msgs[1].2, "msg3");
1057 }
1058
1059 #[tokio::test]
1060 async fn load_messages_range_respects_limit() {
1061 let store = test_store().await;
1062 let cid = store.create_conversation().await.unwrap();
1063
1064 store.save_message(cid, "user", "msg1").await.unwrap();
1065 store.save_message(cid, "assistant", "msg2").await.unwrap();
1066 store.save_message(cid, "user", "msg3").await.unwrap();
1067
1068 let msgs = store
1069 .load_messages_range(cid, MessageId(0), 2)
1070 .await
1071 .unwrap();
1072 assert_eq!(msgs.len(), 2);
1073 }
1074
1075 #[tokio::test]
1076 async fn keyword_search_basic() {
1077 let store = test_store().await;
1078 let cid = store.create_conversation().await.unwrap();
1079
1080 store
1081 .save_message(cid, "user", "rust programming language")
1082 .await
1083 .unwrap();
1084 store
1085 .save_message(cid, "assistant", "python is great too")
1086 .await
1087 .unwrap();
1088 store
1089 .save_message(cid, "user", "I love rust and cargo")
1090 .await
1091 .unwrap();
1092
1093 let results = store.keyword_search("rust", 10, None).await.unwrap();
1094 assert_eq!(results.len(), 2);
1095 assert!(results.iter().all(|(_, score)| *score > 0.0));
1096 }
1097
1098 #[tokio::test]
1099 async fn keyword_search_with_conversation_filter() {
1100 let store = test_store().await;
1101 let cid1 = store.create_conversation().await.unwrap();
1102 let cid2 = store.create_conversation().await.unwrap();
1103
1104 store
1105 .save_message(cid1, "user", "hello world")
1106 .await
1107 .unwrap();
1108 store
1109 .save_message(cid2, "user", "hello universe")
1110 .await
1111 .unwrap();
1112
1113 let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
1114 assert_eq!(results.len(), 1);
1115 }
1116
1117 #[tokio::test]
1118 async fn keyword_search_no_match() {
1119 let store = test_store().await;
1120 let cid = store.create_conversation().await.unwrap();
1121
1122 store
1123 .save_message(cid, "user", "hello world")
1124 .await
1125 .unwrap();
1126
1127 let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
1128 assert!(results.is_empty());
1129 }
1130
1131 #[tokio::test]
1132 async fn keyword_search_respects_limit() {
1133 let store = test_store().await;
1134 let cid = store.create_conversation().await.unwrap();
1135
1136 for i in 0..10 {
1137 store
1138 .save_message(cid, "user", &format!("test message {i}"))
1139 .await
1140 .unwrap();
1141 }
1142
1143 let results = store.keyword_search("test", 3, None).await.unwrap();
1144 assert_eq!(results.len(), 3);
1145 }
1146
1147 #[test]
1148 fn sanitize_fts5_query_strips_special_chars() {
1149 assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
1150 assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
1151 assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
1152 assert_eq!(sanitize_fts5_query(" "), "");
1153 assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
1154 }
1155
1156 #[tokio::test]
1157 async fn keyword_search_with_special_chars_does_not_error() {
1158 let store = test_store().await;
1159 let cid = store.create_conversation().await.unwrap();
1160 store
1161 .save_message(cid, "user", "skill audit info")
1162 .await
1163 .unwrap();
1164 store
1167 .keyword_search("skill-audit, confidence=0.1", 10, None)
1168 .await
1169 .unwrap();
1170 }
1171
1172 #[tokio::test]
1173 async fn save_message_with_metadata_stores_visibility() {
1174 let store = test_store().await;
1175 let cid = store.create_conversation().await.unwrap();
1176
1177 let id = store
1178 .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
1179 .await
1180 .unwrap();
1181
1182 let history = store.load_history(cid, 10).await.unwrap();
1183 assert_eq!(history.len(), 1);
1184 assert!(!history[0].metadata.agent_visible);
1185 assert!(history[0].metadata.user_visible);
1186 assert_eq!(id, MessageId(1));
1187 }
1188
1189 #[tokio::test]
1190 async fn load_history_filtered_by_agent_visible() {
1191 let store = test_store().await;
1192 let cid = store.create_conversation().await.unwrap();
1193
1194 store
1195 .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
1196 .await
1197 .unwrap();
1198 store
1199 .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
1200 .await
1201 .unwrap();
1202
1203 let agent_msgs = store
1204 .load_history_filtered(cid, 50, Some(true), None)
1205 .await
1206 .unwrap();
1207 assert_eq!(agent_msgs.len(), 1);
1208 assert_eq!(agent_msgs[0].content, "visible to agent");
1209 }
1210
1211 #[tokio::test]
1212 async fn load_history_filtered_by_user_visible() {
1213 let store = test_store().await;
1214 let cid = store.create_conversation().await.unwrap();
1215
1216 store
1217 .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
1218 .await
1219 .unwrap();
1220 store
1221 .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
1222 .await
1223 .unwrap();
1224
1225 let user_msgs = store
1226 .load_history_filtered(cid, 50, None, Some(true))
1227 .await
1228 .unwrap();
1229 assert_eq!(user_msgs.len(), 1);
1230 assert_eq!(user_msgs[0].content, "user sees this");
1231 }
1232
1233 #[tokio::test]
1234 async fn load_history_filtered_no_filter_returns_all() {
1235 let store = test_store().await;
1236 let cid = store.create_conversation().await.unwrap();
1237
1238 store
1239 .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
1240 .await
1241 .unwrap();
1242 store
1243 .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
1244 .await
1245 .unwrap();
1246
1247 let all_msgs = store
1248 .load_history_filtered(cid, 50, None, None)
1249 .await
1250 .unwrap();
1251 assert_eq!(all_msgs.len(), 2);
1252 }
1253
1254 #[tokio::test]
1255 async fn replace_conversation_marks_originals_and_inserts_summary() {
1256 let store = test_store().await;
1257 let cid = store.create_conversation().await.unwrap();
1258
1259 let id1 = store.save_message(cid, "user", "first").await.unwrap();
1260 let id2 = store
1261 .save_message(cid, "assistant", "second")
1262 .await
1263 .unwrap();
1264 let id3 = store.save_message(cid, "user", "third").await.unwrap();
1265
1266 let summary_id = store
1267 .replace_conversation(cid, id1..=id2, "system", "summary text")
1268 .await
1269 .unwrap();
1270
1271 let all = store.load_history(cid, 50).await.unwrap();
1273 let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
1275 assert!(!by_id1.metadata.agent_visible);
1276 assert!(by_id1.metadata.user_visible);
1277
1278 let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
1279 assert!(!by_id2.metadata.agent_visible);
1280
1281 let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
1282 assert!(by_id3.metadata.agent_visible);
1283
1284 let summary = all.iter().find(|m| m.content == "summary text").unwrap();
1286 assert!(summary.metadata.agent_visible);
1287 assert!(!summary.metadata.user_visible);
1288 assert!(summary_id > id3);
1289 }
1290
1291 #[tokio::test]
1292 async fn oldest_message_ids_returns_in_order() {
1293 let store = test_store().await;
1294 let cid = store.create_conversation().await.unwrap();
1295
1296 let id1 = store.save_message(cid, "user", "a").await.unwrap();
1297 let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
1298 let id3 = store.save_message(cid, "user", "c").await.unwrap();
1299
1300 let ids = store.oldest_message_ids(cid, 2).await.unwrap();
1301 assert_eq!(ids, vec![id1, id2]);
1302 assert!(ids[0] < ids[1]);
1303
1304 let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
1305 assert_eq!(all_ids, vec![id1, id2, id3]);
1306 }
1307
1308 #[tokio::test]
1309 async fn message_metadata_default_both_visible() {
1310 let store = test_store().await;
1311 let cid = store.create_conversation().await.unwrap();
1312
1313 store.save_message(cid, "user", "normal").await.unwrap();
1314
1315 let history = store.load_history(cid, 10).await.unwrap();
1316 assert!(history[0].metadata.agent_visible);
1317 assert!(history[0].metadata.user_visible);
1318 assert!(history[0].metadata.compacted_at.is_none());
1319 }
1320
1321 #[tokio::test]
1322 async fn load_history_empty_parts_json_fast_path() {
1323 let store = test_store().await;
1324 let cid = store.create_conversation().await.unwrap();
1325
1326 store
1327 .save_message_with_parts(cid, "user", "hello", "[]")
1328 .await
1329 .unwrap();
1330
1331 let history = store.load_history(cid, 10).await.unwrap();
1332 assert_eq!(history.len(), 1);
1333 assert!(
1334 history[0].parts.is_empty(),
1335 "\"[]\" fast-path must yield empty parts Vec"
1336 );
1337 }
1338
1339 #[tokio::test]
1340 async fn load_history_non_empty_parts_json_parsed() {
1341 let store = test_store().await;
1342 let cid = store.create_conversation().await.unwrap();
1343
1344 let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
1345 tool_use_id: "t1".into(),
1346 content: "result".into(),
1347 is_error: false,
1348 }])
1349 .unwrap();
1350
1351 store
1352 .save_message_with_parts(cid, "user", "hello", &parts_json)
1353 .await
1354 .unwrap();
1355
1356 let history = store.load_history(cid, 10).await.unwrap();
1357 assert_eq!(history.len(), 1);
1358 assert_eq!(history[0].parts.len(), 1);
1359 assert!(
1360 matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
1361 );
1362 }
1363
1364 #[tokio::test]
1365 async fn message_by_id_empty_parts_json_fast_path() {
1366 let store = test_store().await;
1367 let cid = store.create_conversation().await.unwrap();
1368
1369 let id = store
1370 .save_message_with_parts(cid, "user", "msg", "[]")
1371 .await
1372 .unwrap();
1373
1374 let msg = store.message_by_id(id).await.unwrap().unwrap();
1375 assert!(
1376 msg.parts.is_empty(),
1377 "\"[]\" fast-path must yield empty parts Vec in message_by_id"
1378 );
1379 }
1380
1381 #[tokio::test]
1382 async fn messages_by_ids_empty_parts_json_fast_path() {
1383 let store = test_store().await;
1384 let cid = store.create_conversation().await.unwrap();
1385
1386 let id = store
1387 .save_message_with_parts(cid, "user", "msg", "[]")
1388 .await
1389 .unwrap();
1390
1391 let results = store.messages_by_ids(&[id]).await.unwrap();
1392 assert_eq!(results.len(), 1);
1393 assert!(
1394 results[0].1.parts.is_empty(),
1395 "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
1396 );
1397 }
1398
1399 #[tokio::test]
1400 async fn load_history_filtered_empty_parts_json_fast_path() {
1401 let store = test_store().await;
1402 let cid = store.create_conversation().await.unwrap();
1403
1404 store
1405 .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
1406 .await
1407 .unwrap();
1408
1409 let msgs = store
1410 .load_history_filtered(cid, 10, Some(true), None)
1411 .await
1412 .unwrap();
1413 assert_eq!(msgs.len(), 1);
1414 assert!(
1415 msgs[0].parts.is_empty(),
1416 "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
1417 );
1418 }
1419
1420 #[tokio::test]
1423 async fn keyword_search_with_time_range_empty_query_returns_empty() {
1424 let store = test_store().await;
1425 let cid = store.create_conversation().await.unwrap();
1426 store
1427 .save_message(cid, "user", "rust programming")
1428 .await
1429 .unwrap();
1430
1431 let results = store
1433 .keyword_search_with_time_range("", 10, None, None, None)
1434 .await
1435 .unwrap();
1436 assert!(results.is_empty());
1437 }
1438
1439 #[tokio::test]
1440 async fn keyword_search_with_time_range_no_bounds_matches_like_keyword_search() {
1441 let store = test_store().await;
1442 let cid = store.create_conversation().await.unwrap();
1443 store
1444 .save_message(cid, "user", "rust async programming")
1445 .await
1446 .unwrap();
1447 store
1448 .save_message(cid, "assistant", "python tutorial")
1449 .await
1450 .unwrap();
1451
1452 let results = store
1454 .keyword_search_with_time_range("rust", 10, None, None, None)
1455 .await
1456 .unwrap();
1457 assert_eq!(results.len(), 1);
1458 }
1459
1460 #[tokio::test]
1461 async fn keyword_search_with_time_range_after_bound_excludes_old_messages() {
1462 let store = test_store().await;
1463 let cid = store.create_conversation().await.unwrap();
1464
1465 store
1466 .save_message(cid, "user", "rust programming guide")
1467 .await
1468 .unwrap();
1469 store
1470 .save_message(cid, "user", "rust async patterns")
1471 .await
1472 .unwrap();
1473
1474 let results = store
1476 .keyword_search_with_time_range("rust", 10, None, Some("2099-01-01 00:00:00"), None)
1477 .await
1478 .unwrap();
1479 assert!(results.is_empty(), "no messages after year 2099");
1480 }
1481
1482 #[tokio::test]
1483 async fn keyword_search_with_time_range_before_bound_excludes_future_messages() {
1484 let store = test_store().await;
1485 let cid = store.create_conversation().await.unwrap();
1486
1487 store
1488 .save_message(cid, "user", "rust programming guide")
1489 .await
1490 .unwrap();
1491
1492 let results = store
1494 .keyword_search_with_time_range("rust", 10, None, None, Some("2000-01-01 00:00:00"))
1495 .await
1496 .unwrap();
1497 assert!(results.is_empty(), "no messages before year 2000");
1498 }
1499
1500 #[tokio::test]
1501 async fn keyword_search_with_time_range_wide_bounds_returns_results() {
1502 let store = test_store().await;
1503 let cid = store.create_conversation().await.unwrap();
1504
1505 store
1506 .save_message(cid, "user", "rust programming guide")
1507 .await
1508 .unwrap();
1509 store
1510 .save_message(cid, "assistant", "python basics")
1511 .await
1512 .unwrap();
1513
1514 let results = store
1516 .keyword_search_with_time_range(
1517 "rust",
1518 10,
1519 None,
1520 Some("2000-01-01 00:00:00"),
1521 Some("2099-12-31 23:59:59"),
1522 )
1523 .await
1524 .unwrap();
1525 assert_eq!(results.len(), 1);
1526 }
1527
1528 #[tokio::test]
1529 async fn keyword_search_with_time_range_conversation_filter() {
1530 let store = test_store().await;
1531 let cid1 = store.create_conversation().await.unwrap();
1532 let cid2 = store.create_conversation().await.unwrap();
1533
1534 store
1535 .save_message(cid1, "user", "rust memory safety")
1536 .await
1537 .unwrap();
1538 store
1539 .save_message(cid2, "user", "rust async patterns")
1540 .await
1541 .unwrap();
1542
1543 let results = store
1544 .keyword_search_with_time_range(
1545 "rust",
1546 10,
1547 Some(cid1),
1548 Some("2000-01-01 00:00:00"),
1549 Some("2099-12-31 23:59:59"),
1550 )
1551 .await
1552 .unwrap();
1553 assert_eq!(
1554 results.len(),
1555 1,
1556 "conversation filter must restrict to cid1 only"
1557 );
1558 }
1559}