1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
/// Turns arbitrary text into a string that is safe to bind to an FTS5 `MATCH` clause.
///
/// The input is split on every non-alphanumeric character, empty fragments are
/// dropped, and the surviving tokens are joined with single spaces, so query
/// punctuation such as `"`, `*`, `(`, `)`, `^`, `-`, `+` and `:` never reaches
/// the FTS5 parser.
///
/// Tokens spelled exactly `AND`, `OR` or `NOT` are lower-cased. FTS5 treats
/// those upper-case barewords as boolean operators, and a leading, trailing or
/// doubled operator (for example `"rust OR"`) makes `MATCH` fail with a syntax
/// error. The operators are case sensitive, so the lower-case spelling is
/// parsed as an ordinary search term.
///
/// Returns an empty string when the input contains no alphanumeric characters.
pub(crate) fn sanitize_fts5_query(query: &str) -> String {
    query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        // NOTE(review): assumes the FTS table uses a case-folding tokenizer (the
        // FTS5 default), so the lower-cased word still matches the same rows.
        .map(|token| match token {
            "AND" => "and",
            "OR" => "or",
            "NOT" => "not",
            other => other,
        })
        .collect::<Vec<_>>()
        .join(" ")
}
26
27fn parse_role(s: &str) -> Role {
28 match s {
29 "assistant" => Role::Assistant,
30 "system" => Role::System,
31 _ => Role::User,
32 }
33}
34
35#[must_use]
36pub fn role_str(role: Role) -> &'static str {
37 match role {
38 Role::System => "system",
39 Role::User => "user",
40 Role::Assistant => "assistant",
41 }
42}
43
44fn parse_parts_json(role_str: &str, parts_json: &str) -> Vec<MessagePart> {
49 if parts_json == "[]" {
50 return vec![];
51 }
52 match serde_json::from_str(parts_json) {
53 Ok(p) => p,
54 Err(e) => {
55 let truncated = parts_json.chars().take(120).collect::<String>();
56 tracing::warn!(
57 role = %role_str,
58 parts_json = %truncated,
59 error = %e,
60 "failed to deserialize message parts, falling back to empty"
61 );
62 vec![]
63 }
64 }
65}
66
67impl SqliteStore {
68 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
74 let row: (ConversationId,) =
75 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
76 .fetch_one(&self.pool)
77 .await?;
78 Ok(row.0)
79 }
80
81 pub async fn save_message(
87 &self,
88 conversation_id: ConversationId,
89 role: &str,
90 content: &str,
91 ) -> Result<MessageId, MemoryError> {
92 self.save_message_with_parts(conversation_id, role, content, "[]")
93 .await
94 }
95
96 pub async fn save_message_with_parts(
102 &self,
103 conversation_id: ConversationId,
104 role: &str,
105 content: &str,
106 parts_json: &str,
107 ) -> Result<MessageId, MemoryError> {
108 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
109 .await
110 }
111
112 pub async fn save_message_with_metadata(
118 &self,
119 conversation_id: ConversationId,
120 role: &str,
121 content: &str,
122 parts_json: &str,
123 agent_visible: bool,
124 user_visible: bool,
125 ) -> Result<MessageId, MemoryError> {
126 let row: (MessageId,) = sqlx::query_as(
127 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
128 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
129 )
130 .bind(conversation_id)
131 .bind(role)
132 .bind(content)
133 .bind(parts_json)
134 .bind(i64::from(agent_visible))
135 .bind(i64::from(user_visible))
136 .fetch_one(&self.pool)
137 .await?;
138 Ok(row.0)
139 }
140
141 pub async fn load_history(
147 &self,
148 conversation_id: ConversationId,
149 limit: u32,
150 ) -> Result<Vec<Message>, MemoryError> {
151 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
152 "SELECT role, content, parts, agent_visible, user_visible FROM (\
153 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
154 WHERE conversation_id = ? AND deleted_at IS NULL \
155 ORDER BY id DESC \
156 LIMIT ?\
157 ) ORDER BY id ASC",
158 )
159 .bind(conversation_id)
160 .bind(limit)
161 .fetch_all(&self.pool)
162 .await?;
163
164 let messages = rows
165 .into_iter()
166 .map(
167 |(role_str, content, parts_json, agent_visible, user_visible)| {
168 let parts = parse_parts_json(&role_str, &parts_json);
169 Message {
170 role: parse_role(&role_str),
171 content,
172 parts,
173 metadata: MessageMetadata {
174 agent_visible: agent_visible != 0,
175 user_visible: user_visible != 0,
176 compacted_at: None,
177 deferred_summary: None,
178 },
179 }
180 },
181 )
182 .collect();
183 Ok(messages)
184 }
185
186 pub async fn load_history_filtered(
194 &self,
195 conversation_id: ConversationId,
196 limit: u32,
197 agent_visible: Option<bool>,
198 user_visible: Option<bool>,
199 ) -> Result<Vec<Message>, MemoryError> {
200 let av = agent_visible.map(i64::from);
201 let uv = user_visible.map(i64::from);
202
203 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
204 "WITH recent AS (\
205 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
206 WHERE conversation_id = ? \
207 AND deleted_at IS NULL \
208 AND (? IS NULL OR agent_visible = ?) \
209 AND (? IS NULL OR user_visible = ?) \
210 ORDER BY id DESC \
211 LIMIT ?\
212 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
213 )
214 .bind(conversation_id)
215 .bind(av)
216 .bind(av)
217 .bind(uv)
218 .bind(uv)
219 .bind(limit)
220 .fetch_all(&self.pool)
221 .await?;
222
223 let messages = rows
224 .into_iter()
225 .map(
226 |(role_str, content, parts_json, agent_visible, user_visible)| {
227 let parts = parse_parts_json(&role_str, &parts_json);
228 Message {
229 role: parse_role(&role_str),
230 content,
231 parts,
232 metadata: MessageMetadata {
233 agent_visible: agent_visible != 0,
234 user_visible: user_visible != 0,
235 compacted_at: None,
236 deferred_summary: None,
237 },
238 }
239 },
240 )
241 .collect();
242 Ok(messages)
243 }
244
245 pub async fn replace_conversation(
257 &self,
258 conversation_id: ConversationId,
259 compacted_range: std::ops::RangeInclusive<MessageId>,
260 summary_role: &str,
261 summary_content: &str,
262 ) -> Result<MessageId, MemoryError> {
263 let now = {
264 let secs = std::time::SystemTime::now()
265 .duration_since(std::time::UNIX_EPOCH)
266 .unwrap_or_default()
267 .as_secs();
268 format!("{secs}")
269 };
270 let start_id = compacted_range.start().0;
271 let end_id = compacted_range.end().0;
272
273 let mut tx = self.pool.begin().await?;
274
275 sqlx::query(
276 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
277 WHERE conversation_id = ? AND id >= ? AND id <= ?",
278 )
279 .bind(&now)
280 .bind(conversation_id)
281 .bind(start_id)
282 .bind(end_id)
283 .execute(&mut *tx)
284 .await?;
285
286 let row: (MessageId,) = sqlx::query_as(
287 "INSERT INTO messages \
288 (conversation_id, role, content, parts, agent_visible, user_visible) \
289 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
290 )
291 .bind(conversation_id)
292 .bind(summary_role)
293 .bind(summary_content)
294 .fetch_one(&mut *tx)
295 .await?;
296
297 tx.commit().await?;
298
299 Ok(row.0)
300 }
301
302 pub async fn oldest_message_ids(
308 &self,
309 conversation_id: ConversationId,
310 n: u32,
311 ) -> Result<Vec<MessageId>, MemoryError> {
312 let rows: Vec<(MessageId,)> = sqlx::query_as(
313 "SELECT id FROM messages WHERE conversation_id = ? AND deleted_at IS NULL ORDER BY id ASC LIMIT ?",
314 )
315 .bind(conversation_id)
316 .bind(n)
317 .fetch_all(&self.pool)
318 .await?;
319 Ok(rows.into_iter().map(|r| r.0).collect())
320 }
321
322 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
328 let row: Option<(ConversationId,)> =
329 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
330 .fetch_optional(&self.pool)
331 .await?;
332 Ok(row.map(|r| r.0))
333 }
334
335 pub async fn message_by_id(
341 &self,
342 message_id: MessageId,
343 ) -> Result<Option<Message>, MemoryError> {
344 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
345 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ? AND deleted_at IS NULL",
346 )
347 .bind(message_id)
348 .fetch_optional(&self.pool)
349 .await?;
350
351 Ok(row.map(
352 |(role_str, content, parts_json, agent_visible, user_visible)| {
353 let parts = parse_parts_json(&role_str, &parts_json);
354 Message {
355 role: parse_role(&role_str),
356 content,
357 parts,
358 metadata: MessageMetadata {
359 agent_visible: agent_visible != 0,
360 user_visible: user_visible != 0,
361 compacted_at: None,
362 deferred_summary: None,
363 },
364 }
365 },
366 ))
367 }
368
369 pub async fn messages_by_ids(
375 &self,
376 ids: &[MessageId],
377 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
378 if ids.is_empty() {
379 return Ok(Vec::new());
380 }
381
382 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
383
384 let query = format!(
385 "SELECT id, role, content, parts FROM messages \
386 WHERE id IN ({placeholders}) AND agent_visible = 1 AND deleted_at IS NULL"
387 );
388 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
389 for &id in ids {
390 q = q.bind(id);
391 }
392
393 let rows = q.fetch_all(&self.pool).await?;
394
395 Ok(rows
396 .into_iter()
397 .map(|(id, role_str, content, parts_json)| {
398 let parts = parse_parts_json(&role_str, &parts_json);
399 (
400 id,
401 Message {
402 role: parse_role(&role_str),
403 content,
404 parts,
405 metadata: MessageMetadata::default(),
406 },
407 )
408 })
409 .collect())
410 }
411
412 pub async fn unembedded_message_ids(
418 &self,
419 limit: Option<usize>,
420 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
421 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
422
423 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
424 "SELECT m.id, m.conversation_id, m.role, m.content \
425 FROM messages m \
426 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
427 WHERE em.id IS NULL AND m.deleted_at IS NULL \
428 ORDER BY m.id ASC \
429 LIMIT ?",
430 )
431 .bind(effective_limit)
432 .fetch_all(&self.pool)
433 .await?;
434
435 Ok(rows)
436 }
437
438 pub async fn count_messages(
444 &self,
445 conversation_id: ConversationId,
446 ) -> Result<i64, MemoryError> {
447 let row: (i64,) = sqlx::query_as(
448 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND deleted_at IS NULL",
449 )
450 .bind(conversation_id)
451 .fetch_one(&self.pool)
452 .await?;
453 Ok(row.0)
454 }
455
456 pub async fn count_messages_after(
462 &self,
463 conversation_id: ConversationId,
464 after_id: MessageId,
465 ) -> Result<i64, MemoryError> {
466 let row: (i64,) =
467 sqlx::query_as(
468 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL",
469 )
470 .bind(conversation_id)
471 .bind(after_id)
472 .fetch_one(&self.pool)
473 .await?;
474 Ok(row.0)
475 }
476
477 pub async fn keyword_search(
486 &self,
487 query: &str,
488 limit: usize,
489 conversation_id: Option<ConversationId>,
490 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
491 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
492 let safe_query = sanitize_fts5_query(query);
493 if safe_query.is_empty() {
494 return Ok(Vec::new());
495 }
496
497 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
498 sqlx::query_as(
499 "SELECT m.id, -rank AS score \
500 FROM messages_fts f \
501 JOIN messages m ON m.id = f.rowid \
502 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
503 ORDER BY rank \
504 LIMIT ?",
505 )
506 .bind(&safe_query)
507 .bind(cid)
508 .bind(effective_limit)
509 .fetch_all(&self.pool)
510 .await?
511 } else {
512 sqlx::query_as(
513 "SELECT m.id, -rank AS score \
514 FROM messages_fts f \
515 JOIN messages m ON m.id = f.rowid \
516 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
517 ORDER BY rank \
518 LIMIT ?",
519 )
520 .bind(&safe_query)
521 .bind(effective_limit)
522 .fetch_all(&self.pool)
523 .await?
524 };
525
526 Ok(rows)
527 }
528
529 pub async fn keyword_search_with_time_range(
542 &self,
543 query: &str,
544 limit: usize,
545 conversation_id: Option<ConversationId>,
546 after: Option<&str>,
547 before: Option<&str>,
548 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
549 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
550 let safe_query = sanitize_fts5_query(query);
551 if safe_query.is_empty() {
552 return Ok(Vec::new());
553 }
554
555 let after_clause = if after.is_some() {
557 " AND m.created_at > ?"
558 } else {
559 ""
560 };
561 let before_clause = if before.is_some() {
562 " AND m.created_at < ?"
563 } else {
564 ""
565 };
566 let conv_clause = if conversation_id.is_some() {
567 " AND m.conversation_id = ?"
568 } else {
569 ""
570 };
571
572 let sql = format!(
573 "SELECT m.id, -rank AS score \
574 FROM messages_fts f \
575 JOIN messages m ON m.id = f.rowid \
576 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL\
577 {after_clause}{before_clause}{conv_clause} \
578 ORDER BY rank \
579 LIMIT ?"
580 );
581
582 let mut q = sqlx::query_as::<_, (MessageId, f64)>(&sql).bind(&safe_query);
583 if let Some(a) = after {
584 q = q.bind(a);
585 }
586 if let Some(b) = before {
587 q = q.bind(b);
588 }
589 if let Some(cid) = conversation_id {
590 q = q.bind(cid);
591 }
592 q = q.bind(effective_limit);
593
594 Ok(q.fetch_all(&self.pool).await?)
595 }
596
597 pub async fn message_timestamps(
605 &self,
606 ids: &[MessageId],
607 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
608 if ids.is_empty() {
609 return Ok(std::collections::HashMap::new());
610 }
611
612 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
613 let query = format!(
614 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
615 FROM messages WHERE id IN ({placeholders}) AND deleted_at IS NULL"
616 );
617 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
618 for &id in ids {
619 q = q.bind(id);
620 }
621
622 let rows = q.fetch_all(&self.pool).await?;
623 Ok(rows.into_iter().collect())
624 }
625
626 pub async fn load_messages_range(
632 &self,
633 conversation_id: ConversationId,
634 after_message_id: MessageId,
635 limit: usize,
636 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
637 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
638
639 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
640 "SELECT id, role, content FROM messages \
641 WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL \
642 ORDER BY id ASC LIMIT ?",
643 )
644 .bind(conversation_id)
645 .bind(after_message_id)
646 .bind(effective_limit)
647 .fetch_all(&self.pool)
648 .await?;
649
650 Ok(rows)
651 }
652
653 pub async fn get_eviction_candidates(
661 &self,
662 ) -> Result<Vec<crate::eviction::EvictionEntry>, crate::error::MemoryError> {
663 let rows: Vec<(MessageId, String, Option<String>, i64)> = sqlx::query_as(
664 "SELECT id, created_at, last_accessed, access_count \
665 FROM messages WHERE deleted_at IS NULL",
666 )
667 .fetch_all(&self.pool)
668 .await?;
669
670 Ok(rows
671 .into_iter()
672 .map(
673 |(id, created_at, last_accessed, access_count)| crate::eviction::EvictionEntry {
674 id,
675 created_at,
676 last_accessed,
677 access_count: access_count.try_into().unwrap_or(0),
678 },
679 )
680 .collect())
681 }
682
683 pub async fn soft_delete_messages(
691 &self,
692 ids: &[MessageId],
693 ) -> Result<(), crate::error::MemoryError> {
694 if ids.is_empty() {
695 return Ok(());
696 }
697 for &id in ids {
699 sqlx::query(
700 "UPDATE messages SET deleted_at = datetime('now') WHERE id = ? AND deleted_at IS NULL",
701 )
702 .bind(id)
703 .execute(&self.pool)
704 .await?;
705 }
706 Ok(())
707 }
708
709 pub async fn get_soft_deleted_message_ids(
715 &self,
716 ) -> Result<Vec<MessageId>, crate::error::MemoryError> {
717 let rows: Vec<(MessageId,)> = sqlx::query_as(
718 "SELECT id FROM messages WHERE deleted_at IS NOT NULL AND qdrant_cleaned = 0",
719 )
720 .fetch_all(&self.pool)
721 .await?;
722 Ok(rows.into_iter().map(|(id,)| id).collect())
723 }
724
725 pub async fn mark_qdrant_cleaned(
731 &self,
732 ids: &[MessageId],
733 ) -> Result<(), crate::error::MemoryError> {
734 for &id in ids {
735 sqlx::query("UPDATE messages SET qdrant_cleaned = 1 WHERE id = ?")
736 .bind(id)
737 .execute(&self.pool)
738 .await?;
739 }
740 Ok(())
741 }
742}
743
744#[cfg(test)]
745mod tests {
746 use super::*;
747
748 async fn test_store() -> SqliteStore {
749 SqliteStore::new(":memory:").await.unwrap()
750 }
751
752 #[tokio::test]
753 async fn create_conversation_returns_id() {
754 let store = test_store().await;
755 let id1 = store.create_conversation().await.unwrap();
756 let id2 = store.create_conversation().await.unwrap();
757 assert_eq!(id1, ConversationId(1));
758 assert_eq!(id2, ConversationId(2));
759 }
760
761 #[tokio::test]
762 async fn save_and_load_messages() {
763 let store = test_store().await;
764 let cid = store.create_conversation().await.unwrap();
765
766 let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
767 let msg_id2 = store
768 .save_message(cid, "assistant", "hi there")
769 .await
770 .unwrap();
771
772 assert_eq!(msg_id1, MessageId(1));
773 assert_eq!(msg_id2, MessageId(2));
774
775 let history = store.load_history(cid, 50).await.unwrap();
776 assert_eq!(history.len(), 2);
777 assert_eq!(history[0].role, Role::User);
778 assert_eq!(history[0].content, "hello");
779 assert_eq!(history[1].role, Role::Assistant);
780 assert_eq!(history[1].content, "hi there");
781 }
782
783 #[tokio::test]
784 async fn load_history_respects_limit() {
785 let store = test_store().await;
786 let cid = store.create_conversation().await.unwrap();
787
788 for i in 0..10 {
789 store
790 .save_message(cid, "user", &format!("msg {i}"))
791 .await
792 .unwrap();
793 }
794
795 let history = store.load_history(cid, 3).await.unwrap();
796 assert_eq!(history.len(), 3);
797 assert_eq!(history[0].content, "msg 7");
798 assert_eq!(history[1].content, "msg 8");
799 assert_eq!(history[2].content, "msg 9");
800 }
801
802 #[tokio::test]
803 async fn latest_conversation_id_empty() {
804 let store = test_store().await;
805 assert!(store.latest_conversation_id().await.unwrap().is_none());
806 }
807
808 #[tokio::test]
809 async fn latest_conversation_id_returns_newest() {
810 let store = test_store().await;
811 store.create_conversation().await.unwrap();
812 let id2 = store.create_conversation().await.unwrap();
813 assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
814 }
815
816 #[tokio::test]
817 async fn messages_isolated_per_conversation() {
818 let store = test_store().await;
819 let cid1 = store.create_conversation().await.unwrap();
820 let cid2 = store.create_conversation().await.unwrap();
821
822 store.save_message(cid1, "user", "conv1").await.unwrap();
823 store.save_message(cid2, "user", "conv2").await.unwrap();
824
825 let h1 = store.load_history(cid1, 50).await.unwrap();
826 let h2 = store.load_history(cid2, 50).await.unwrap();
827 assert_eq!(h1.len(), 1);
828 assert_eq!(h1[0].content, "conv1");
829 assert_eq!(h2.len(), 1);
830 assert_eq!(h2[0].content, "conv2");
831 }
832
833 #[tokio::test]
834 async fn pool_accessor_returns_valid_pool() {
835 let store = test_store().await;
836 let pool = store.pool();
837 let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
838 assert_eq!(row.0, 1);
839 }
840
841 #[tokio::test]
842 async fn embeddings_metadata_table_exists() {
843 let store = test_store().await;
844 let result: (i64,) = sqlx::query_as(
845 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
846 )
847 .fetch_one(store.pool())
848 .await
849 .unwrap();
850 assert_eq!(result.0, 1);
851 }
852
853 #[tokio::test]
854 async fn cascade_delete_removes_embeddings_metadata() {
855 let store = test_store().await;
856 let pool = store.pool();
857
858 let cid = store.create_conversation().await.unwrap();
859 let msg_id = store.save_message(cid, "user", "test").await.unwrap();
860
861 let point_id = uuid::Uuid::new_v4().to_string();
862 sqlx::query(
863 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
864 VALUES (?, ?, ?)",
865 )
866 .bind(msg_id)
867 .bind(&point_id)
868 .bind(768_i64)
869 .execute(pool)
870 .await
871 .unwrap();
872
873 let before: (i64,) =
874 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
875 .bind(msg_id)
876 .fetch_one(pool)
877 .await
878 .unwrap();
879 assert_eq!(before.0, 1);
880
881 sqlx::query("DELETE FROM messages WHERE id = ?")
882 .bind(msg_id)
883 .execute(pool)
884 .await
885 .unwrap();
886
887 let after: (i64,) =
888 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
889 .bind(msg_id)
890 .fetch_one(pool)
891 .await
892 .unwrap();
893 assert_eq!(after.0, 0);
894 }
895
896 #[tokio::test]
897 async fn messages_by_ids_batch_fetch() {
898 let store = test_store().await;
899 let cid = store.create_conversation().await.unwrap();
900 let id1 = store.save_message(cid, "user", "hello").await.unwrap();
901 let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
902 let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
903
904 let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
905 assert_eq!(results.len(), 2);
906 assert_eq!(results[0].0, id1);
907 assert_eq!(results[0].1.content, "hello");
908 assert_eq!(results[1].0, id2);
909 assert_eq!(results[1].1.content, "hi");
910 }
911
912 #[tokio::test]
913 async fn messages_by_ids_empty_input() {
914 let store = test_store().await;
915 let results = store.messages_by_ids(&[]).await.unwrap();
916 assert!(results.is_empty());
917 }
918
919 #[tokio::test]
920 async fn messages_by_ids_nonexistent() {
921 let store = test_store().await;
922 let results = store
923 .messages_by_ids(&[MessageId(999), MessageId(1000)])
924 .await
925 .unwrap();
926 assert!(results.is_empty());
927 }
928
929 #[tokio::test]
930 async fn message_by_id_fetches_existing() {
931 let store = test_store().await;
932 let cid = store.create_conversation().await.unwrap();
933 let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
934
935 let msg = store.message_by_id(msg_id).await.unwrap();
936 assert!(msg.is_some());
937 let msg = msg.unwrap();
938 assert_eq!(msg.role, Role::User);
939 assert_eq!(msg.content, "hello");
940 }
941
942 #[tokio::test]
943 async fn message_by_id_returns_none_for_nonexistent() {
944 let store = test_store().await;
945 let msg = store.message_by_id(MessageId(999)).await.unwrap();
946 assert!(msg.is_none());
947 }
948
949 #[tokio::test]
950 async fn unembedded_message_ids_returns_all_when_none_embedded() {
951 let store = test_store().await;
952 let cid = store.create_conversation().await.unwrap();
953
954 store.save_message(cid, "user", "msg1").await.unwrap();
955 store.save_message(cid, "assistant", "msg2").await.unwrap();
956
957 let unembedded = store.unembedded_message_ids(None).await.unwrap();
958 assert_eq!(unembedded.len(), 2);
959 assert_eq!(unembedded[0].3, "msg1");
960 assert_eq!(unembedded[1].3, "msg2");
961 }
962
963 #[tokio::test]
964 async fn unembedded_message_ids_excludes_embedded() {
965 let store = test_store().await;
966 let pool = store.pool();
967 let cid = store.create_conversation().await.unwrap();
968
969 let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
970 let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
971
972 let point_id = uuid::Uuid::new_v4().to_string();
973 sqlx::query(
974 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
975 VALUES (?, ?, ?)",
976 )
977 .bind(msg_id1)
978 .bind(&point_id)
979 .bind(768_i64)
980 .execute(pool)
981 .await
982 .unwrap();
983
984 let unembedded = store.unembedded_message_ids(None).await.unwrap();
985 assert_eq!(unembedded.len(), 1);
986 assert_eq!(unembedded[0].0, msg_id2);
987 assert_eq!(unembedded[0].3, "msg2");
988 }
989
990 #[tokio::test]
991 async fn unembedded_message_ids_respects_limit() {
992 let store = test_store().await;
993 let cid = store.create_conversation().await.unwrap();
994
995 for i in 0..10 {
996 store
997 .save_message(cid, "user", &format!("msg{i}"))
998 .await
999 .unwrap();
1000 }
1001
1002 let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
1003 assert_eq!(unembedded.len(), 3);
1004 }
1005
1006 #[tokio::test]
1007 async fn count_messages_returns_correct_count() {
1008 let store = test_store().await;
1009 let cid = store.create_conversation().await.unwrap();
1010
1011 assert_eq!(store.count_messages(cid).await.unwrap(), 0);
1012
1013 store.save_message(cid, "user", "msg1").await.unwrap();
1014 store.save_message(cid, "assistant", "msg2").await.unwrap();
1015
1016 assert_eq!(store.count_messages(cid).await.unwrap(), 2);
1017 }
1018
1019 #[tokio::test]
1020 async fn count_messages_after_filters_correctly() {
1021 let store = test_store().await;
1022 let cid = store.create_conversation().await.unwrap();
1023
1024 let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
1025 let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
1026 let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
1027
1028 assert_eq!(
1029 store.count_messages_after(cid, MessageId(0)).await.unwrap(),
1030 3
1031 );
1032 assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
1033 assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
1034 }
1035
1036 #[tokio::test]
1037 async fn load_messages_range_basic() {
1038 let store = test_store().await;
1039 let cid = store.create_conversation().await.unwrap();
1040
1041 let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
1042 let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
1043 let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
1044
1045 let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
1046 assert_eq!(msgs.len(), 2);
1047 assert_eq!(msgs[0].0, msg_id2);
1048 assert_eq!(msgs[0].2, "msg2");
1049 assert_eq!(msgs[1].0, msg_id3);
1050 assert_eq!(msgs[1].2, "msg3");
1051 }
1052
1053 #[tokio::test]
1054 async fn load_messages_range_respects_limit() {
1055 let store = test_store().await;
1056 let cid = store.create_conversation().await.unwrap();
1057
1058 store.save_message(cid, "user", "msg1").await.unwrap();
1059 store.save_message(cid, "assistant", "msg2").await.unwrap();
1060 store.save_message(cid, "user", "msg3").await.unwrap();
1061
1062 let msgs = store
1063 .load_messages_range(cid, MessageId(0), 2)
1064 .await
1065 .unwrap();
1066 assert_eq!(msgs.len(), 2);
1067 }
1068
1069 #[tokio::test]
1070 async fn keyword_search_basic() {
1071 let store = test_store().await;
1072 let cid = store.create_conversation().await.unwrap();
1073
1074 store
1075 .save_message(cid, "user", "rust programming language")
1076 .await
1077 .unwrap();
1078 store
1079 .save_message(cid, "assistant", "python is great too")
1080 .await
1081 .unwrap();
1082 store
1083 .save_message(cid, "user", "I love rust and cargo")
1084 .await
1085 .unwrap();
1086
1087 let results = store.keyword_search("rust", 10, None).await.unwrap();
1088 assert_eq!(results.len(), 2);
1089 assert!(results.iter().all(|(_, score)| *score > 0.0));
1090 }
1091
1092 #[tokio::test]
1093 async fn keyword_search_with_conversation_filter() {
1094 let store = test_store().await;
1095 let cid1 = store.create_conversation().await.unwrap();
1096 let cid2 = store.create_conversation().await.unwrap();
1097
1098 store
1099 .save_message(cid1, "user", "hello world")
1100 .await
1101 .unwrap();
1102 store
1103 .save_message(cid2, "user", "hello universe")
1104 .await
1105 .unwrap();
1106
1107 let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
1108 assert_eq!(results.len(), 1);
1109 }
1110
1111 #[tokio::test]
1112 async fn keyword_search_no_match() {
1113 let store = test_store().await;
1114 let cid = store.create_conversation().await.unwrap();
1115
1116 store
1117 .save_message(cid, "user", "hello world")
1118 .await
1119 .unwrap();
1120
1121 let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
1122 assert!(results.is_empty());
1123 }
1124
1125 #[tokio::test]
1126 async fn keyword_search_respects_limit() {
1127 let store = test_store().await;
1128 let cid = store.create_conversation().await.unwrap();
1129
1130 for i in 0..10 {
1131 store
1132 .save_message(cid, "user", &format!("test message {i}"))
1133 .await
1134 .unwrap();
1135 }
1136
1137 let results = store.keyword_search("test", 3, None).await.unwrap();
1138 assert_eq!(results.len(), 3);
1139 }
1140
1141 #[test]
1142 fn sanitize_fts5_query_strips_special_chars() {
1143 assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
1144 assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
1145 assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
1146 assert_eq!(sanitize_fts5_query(" "), "");
1147 assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
1148 }
1149
1150 #[tokio::test]
1151 async fn keyword_search_with_special_chars_does_not_error() {
1152 let store = test_store().await;
1153 let cid = store.create_conversation().await.unwrap();
1154 store
1155 .save_message(cid, "user", "skill audit info")
1156 .await
1157 .unwrap();
1158 store
1161 .keyword_search("skill-audit, confidence=0.1", 10, None)
1162 .await
1163 .unwrap();
1164 }
1165
1166 #[tokio::test]
1167 async fn save_message_with_metadata_stores_visibility() {
1168 let store = test_store().await;
1169 let cid = store.create_conversation().await.unwrap();
1170
1171 let id = store
1172 .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
1173 .await
1174 .unwrap();
1175
1176 let history = store.load_history(cid, 10).await.unwrap();
1177 assert_eq!(history.len(), 1);
1178 assert!(!history[0].metadata.agent_visible);
1179 assert!(history[0].metadata.user_visible);
1180 assert_eq!(id, MessageId(1));
1181 }
1182
1183 #[tokio::test]
1184 async fn load_history_filtered_by_agent_visible() {
1185 let store = test_store().await;
1186 let cid = store.create_conversation().await.unwrap();
1187
1188 store
1189 .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
1190 .await
1191 .unwrap();
1192 store
1193 .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
1194 .await
1195 .unwrap();
1196
1197 let agent_msgs = store
1198 .load_history_filtered(cid, 50, Some(true), None)
1199 .await
1200 .unwrap();
1201 assert_eq!(agent_msgs.len(), 1);
1202 assert_eq!(agent_msgs[0].content, "visible to agent");
1203 }
1204
1205 #[tokio::test]
1206 async fn load_history_filtered_by_user_visible() {
1207 let store = test_store().await;
1208 let cid = store.create_conversation().await.unwrap();
1209
1210 store
1211 .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
1212 .await
1213 .unwrap();
1214 store
1215 .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
1216 .await
1217 .unwrap();
1218
1219 let user_msgs = store
1220 .load_history_filtered(cid, 50, None, Some(true))
1221 .await
1222 .unwrap();
1223 assert_eq!(user_msgs.len(), 1);
1224 assert_eq!(user_msgs[0].content, "user sees this");
1225 }
1226
1227 #[tokio::test]
1228 async fn load_history_filtered_no_filter_returns_all() {
1229 let store = test_store().await;
1230 let cid = store.create_conversation().await.unwrap();
1231
1232 store
1233 .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
1234 .await
1235 .unwrap();
1236 store
1237 .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
1238 .await
1239 .unwrap();
1240
1241 let all_msgs = store
1242 .load_history_filtered(cid, 50, None, None)
1243 .await
1244 .unwrap();
1245 assert_eq!(all_msgs.len(), 2);
1246 }
1247
1248 #[tokio::test]
1249 async fn replace_conversation_marks_originals_and_inserts_summary() {
1250 let store = test_store().await;
1251 let cid = store.create_conversation().await.unwrap();
1252
1253 let id1 = store.save_message(cid, "user", "first").await.unwrap();
1254 let id2 = store
1255 .save_message(cid, "assistant", "second")
1256 .await
1257 .unwrap();
1258 let id3 = store.save_message(cid, "user", "third").await.unwrap();
1259
1260 let summary_id = store
1261 .replace_conversation(cid, id1..=id2, "system", "summary text")
1262 .await
1263 .unwrap();
1264
1265 let all = store.load_history(cid, 50).await.unwrap();
1267 let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
1269 assert!(!by_id1.metadata.agent_visible);
1270 assert!(by_id1.metadata.user_visible);
1271
1272 let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
1273 assert!(!by_id2.metadata.agent_visible);
1274
1275 let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
1276 assert!(by_id3.metadata.agent_visible);
1277
1278 let summary = all.iter().find(|m| m.content == "summary text").unwrap();
1280 assert!(summary.metadata.agent_visible);
1281 assert!(!summary.metadata.user_visible);
1282 assert!(summary_id > id3);
1283 }
1284
1285 #[tokio::test]
1286 async fn oldest_message_ids_returns_in_order() {
1287 let store = test_store().await;
1288 let cid = store.create_conversation().await.unwrap();
1289
1290 let id1 = store.save_message(cid, "user", "a").await.unwrap();
1291 let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
1292 let id3 = store.save_message(cid, "user", "c").await.unwrap();
1293
1294 let ids = store.oldest_message_ids(cid, 2).await.unwrap();
1295 assert_eq!(ids, vec![id1, id2]);
1296 assert!(ids[0] < ids[1]);
1297
1298 let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
1299 assert_eq!(all_ids, vec![id1, id2, id3]);
1300 }
1301
1302 #[tokio::test]
1303 async fn message_metadata_default_both_visible() {
1304 let store = test_store().await;
1305 let cid = store.create_conversation().await.unwrap();
1306
1307 store.save_message(cid, "user", "normal").await.unwrap();
1308
1309 let history = store.load_history(cid, 10).await.unwrap();
1310 assert!(history[0].metadata.agent_visible);
1311 assert!(history[0].metadata.user_visible);
1312 assert!(history[0].metadata.compacted_at.is_none());
1313 }
1314
1315 #[tokio::test]
1316 async fn load_history_empty_parts_json_fast_path() {
1317 let store = test_store().await;
1318 let cid = store.create_conversation().await.unwrap();
1319
1320 store
1321 .save_message_with_parts(cid, "user", "hello", "[]")
1322 .await
1323 .unwrap();
1324
1325 let history = store.load_history(cid, 10).await.unwrap();
1326 assert_eq!(history.len(), 1);
1327 assert!(
1328 history[0].parts.is_empty(),
1329 "\"[]\" fast-path must yield empty parts Vec"
1330 );
1331 }
1332
1333 #[tokio::test]
1334 async fn load_history_non_empty_parts_json_parsed() {
1335 let store = test_store().await;
1336 let cid = store.create_conversation().await.unwrap();
1337
1338 let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
1339 tool_use_id: "t1".into(),
1340 content: "result".into(),
1341 is_error: false,
1342 }])
1343 .unwrap();
1344
1345 store
1346 .save_message_with_parts(cid, "user", "hello", &parts_json)
1347 .await
1348 .unwrap();
1349
1350 let history = store.load_history(cid, 10).await.unwrap();
1351 assert_eq!(history.len(), 1);
1352 assert_eq!(history[0].parts.len(), 1);
1353 assert!(
1354 matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
1355 );
1356 }
1357
1358 #[tokio::test]
1359 async fn message_by_id_empty_parts_json_fast_path() {
1360 let store = test_store().await;
1361 let cid = store.create_conversation().await.unwrap();
1362
1363 let id = store
1364 .save_message_with_parts(cid, "user", "msg", "[]")
1365 .await
1366 .unwrap();
1367
1368 let msg = store.message_by_id(id).await.unwrap().unwrap();
1369 assert!(
1370 msg.parts.is_empty(),
1371 "\"[]\" fast-path must yield empty parts Vec in message_by_id"
1372 );
1373 }
1374
1375 #[tokio::test]
1376 async fn messages_by_ids_empty_parts_json_fast_path() {
1377 let store = test_store().await;
1378 let cid = store.create_conversation().await.unwrap();
1379
1380 let id = store
1381 .save_message_with_parts(cid, "user", "msg", "[]")
1382 .await
1383 .unwrap();
1384
1385 let results = store.messages_by_ids(&[id]).await.unwrap();
1386 assert_eq!(results.len(), 1);
1387 assert!(
1388 results[0].1.parts.is_empty(),
1389 "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
1390 );
1391 }
1392
1393 #[tokio::test]
1394 async fn load_history_filtered_empty_parts_json_fast_path() {
1395 let store = test_store().await;
1396 let cid = store.create_conversation().await.unwrap();
1397
1398 store
1399 .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
1400 .await
1401 .unwrap();
1402
1403 let msgs = store
1404 .load_history_filtered(cid, 10, Some(true), None)
1405 .await
1406 .unwrap();
1407 assert_eq!(msgs.len(), 1);
1408 assert!(
1409 msgs[0].parts.is_empty(),
1410 "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
1411 );
1412 }
1413
1414 #[tokio::test]
1417 async fn keyword_search_with_time_range_empty_query_returns_empty() {
1418 let store = test_store().await;
1419 let cid = store.create_conversation().await.unwrap();
1420 store
1421 .save_message(cid, "user", "rust programming")
1422 .await
1423 .unwrap();
1424
1425 let results = store
1427 .keyword_search_with_time_range("", 10, None, None, None)
1428 .await
1429 .unwrap();
1430 assert!(results.is_empty());
1431 }
1432
1433 #[tokio::test]
1434 async fn keyword_search_with_time_range_no_bounds_matches_like_keyword_search() {
1435 let store = test_store().await;
1436 let cid = store.create_conversation().await.unwrap();
1437 store
1438 .save_message(cid, "user", "rust async programming")
1439 .await
1440 .unwrap();
1441 store
1442 .save_message(cid, "assistant", "python tutorial")
1443 .await
1444 .unwrap();
1445
1446 let results = store
1448 .keyword_search_with_time_range("rust", 10, None, None, None)
1449 .await
1450 .unwrap();
1451 assert_eq!(results.len(), 1);
1452 }
1453
1454 #[tokio::test]
1455 async fn keyword_search_with_time_range_after_bound_excludes_old_messages() {
1456 let store = test_store().await;
1457 let cid = store.create_conversation().await.unwrap();
1458
1459 store
1460 .save_message(cid, "user", "rust programming guide")
1461 .await
1462 .unwrap();
1463 store
1464 .save_message(cid, "user", "rust async patterns")
1465 .await
1466 .unwrap();
1467
1468 let results = store
1470 .keyword_search_with_time_range("rust", 10, None, Some("2099-01-01 00:00:00"), None)
1471 .await
1472 .unwrap();
1473 assert!(results.is_empty(), "no messages after year 2099");
1474 }
1475
1476 #[tokio::test]
1477 async fn keyword_search_with_time_range_before_bound_excludes_future_messages() {
1478 let store = test_store().await;
1479 let cid = store.create_conversation().await.unwrap();
1480
1481 store
1482 .save_message(cid, "user", "rust programming guide")
1483 .await
1484 .unwrap();
1485
1486 let results = store
1488 .keyword_search_with_time_range("rust", 10, None, None, Some("2000-01-01 00:00:00"))
1489 .await
1490 .unwrap();
1491 assert!(results.is_empty(), "no messages before year 2000");
1492 }
1493
1494 #[tokio::test]
1495 async fn keyword_search_with_time_range_wide_bounds_returns_results() {
1496 let store = test_store().await;
1497 let cid = store.create_conversation().await.unwrap();
1498
1499 store
1500 .save_message(cid, "user", "rust programming guide")
1501 .await
1502 .unwrap();
1503 store
1504 .save_message(cid, "assistant", "python basics")
1505 .await
1506 .unwrap();
1507
1508 let results = store
1510 .keyword_search_with_time_range(
1511 "rust",
1512 10,
1513 None,
1514 Some("2000-01-01 00:00:00"),
1515 Some("2099-12-31 23:59:59"),
1516 )
1517 .await
1518 .unwrap();
1519 assert_eq!(results.len(), 1);
1520 }
1521
1522 #[tokio::test]
1523 async fn keyword_search_with_time_range_conversation_filter() {
1524 let store = test_store().await;
1525 let cid1 = store.create_conversation().await.unwrap();
1526 let cid2 = store.create_conversation().await.unwrap();
1527
1528 store
1529 .save_message(cid1, "user", "rust memory safety")
1530 .await
1531 .unwrap();
1532 store
1533 .save_message(cid2, "user", "rust async patterns")
1534 .await
1535 .unwrap();
1536
1537 let results = store
1538 .keyword_search_with_time_range(
1539 "rust",
1540 10,
1541 Some(cid1),
1542 Some("2000-01-01 00:00:00"),
1543 Some("2099-12-31 23:59:59"),
1544 )
1545 .await
1546 .unwrap();
1547 assert_eq!(
1548 results.len(),
1549 1,
1550 "conversation filter must restrict to cid1 only"
1551 );
1552 }
1553}