1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
/// Turns arbitrary user text into a query that is safe to pass to FTS5 `MATCH`.
///
/// The input is split on every non-alphanumeric character, and empty pieces are dropped.
/// The remaining tokens are joined with single spaces, which FTS5 reads as an implicit
/// AND of barewords. This removes all FTS5 syntax characters (quotes, `-`, `+`, `*`,
/// `^`, `:`, parentheses, ...).
///
/// FTS5 also treats the exact uppercase barewords `AND`, `OR` and `NOT` as operators.
/// Left as-is, inputs such as `"NOT"` or `"cats AND"` would be FTS5 syntax errors. These
/// keywords are matched case-sensitively, so they are lowercased here to turn them into
/// ordinary search terms. Any other token is passed through unchanged.
///
/// Returns an empty string when the input contains no alphanumeric characters.
pub(crate) fn sanitize_fts5_query(query: &str) -> String {
    query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| match token {
            "AND" => "and",
            "OR" => "or",
            "NOT" => "not",
            other => other,
        })
        .collect::<Vec<_>>()
        .join(" ")
}
26
27fn parse_role(s: &str) -> Role {
28 match s {
29 "assistant" => Role::Assistant,
30 "system" => Role::System,
31 _ => Role::User,
32 }
33}
34
35#[must_use]
36pub fn role_str(role: Role) -> &'static str {
37 match role {
38 Role::System => "system",
39 Role::User => "user",
40 Role::Assistant => "assistant",
41 }
42}
43
44impl SqliteStore {
45 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
51 let row: (ConversationId,) =
52 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
53 .fetch_one(&self.pool)
54 .await?;
55 Ok(row.0)
56 }
57
58 pub async fn save_message(
64 &self,
65 conversation_id: ConversationId,
66 role: &str,
67 content: &str,
68 ) -> Result<MessageId, MemoryError> {
69 self.save_message_with_parts(conversation_id, role, content, "[]")
70 .await
71 }
72
73 pub async fn save_message_with_parts(
79 &self,
80 conversation_id: ConversationId,
81 role: &str,
82 content: &str,
83 parts_json: &str,
84 ) -> Result<MessageId, MemoryError> {
85 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
86 .await
87 }
88
89 pub async fn save_message_with_metadata(
95 &self,
96 conversation_id: ConversationId,
97 role: &str,
98 content: &str,
99 parts_json: &str,
100 agent_visible: bool,
101 user_visible: bool,
102 ) -> Result<MessageId, MemoryError> {
103 let row: (MessageId,) = sqlx::query_as(
104 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
105 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
106 )
107 .bind(conversation_id)
108 .bind(role)
109 .bind(content)
110 .bind(parts_json)
111 .bind(i64::from(agent_visible))
112 .bind(i64::from(user_visible))
113 .fetch_one(&self.pool)
114 .await?;
115 Ok(row.0)
116 }
117
118 pub async fn load_history(
124 &self,
125 conversation_id: ConversationId,
126 limit: u32,
127 ) -> Result<Vec<Message>, MemoryError> {
128 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
129 "SELECT role, content, parts, agent_visible, user_visible FROM (\
130 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
131 WHERE conversation_id = ? AND deleted_at IS NULL \
132 ORDER BY id DESC \
133 LIMIT ?\
134 ) ORDER BY id ASC",
135 )
136 .bind(conversation_id)
137 .bind(limit)
138 .fetch_all(&self.pool)
139 .await?;
140
141 let messages = rows
142 .into_iter()
143 .map(
144 |(role_str, content, parts_json, agent_visible, user_visible)| {
145 let parts: Vec<MessagePart> = if parts_json == "[]" {
146 vec![]
147 } else {
148 serde_json::from_str(&parts_json).unwrap_or_default()
149 };
150 Message {
151 role: parse_role(&role_str),
152 content,
153 parts,
154 metadata: MessageMetadata {
155 agent_visible: agent_visible != 0,
156 user_visible: user_visible != 0,
157 compacted_at: None,
158 deferred_summary: None,
159 },
160 }
161 },
162 )
163 .collect();
164 Ok(messages)
165 }
166
167 pub async fn load_history_filtered(
175 &self,
176 conversation_id: ConversationId,
177 limit: u32,
178 agent_visible: Option<bool>,
179 user_visible: Option<bool>,
180 ) -> Result<Vec<Message>, MemoryError> {
181 let av = agent_visible.map(i64::from);
182 let uv = user_visible.map(i64::from);
183
184 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
185 "WITH recent AS (\
186 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
187 WHERE conversation_id = ? \
188 AND deleted_at IS NULL \
189 AND (? IS NULL OR agent_visible = ?) \
190 AND (? IS NULL OR user_visible = ?) \
191 ORDER BY id DESC \
192 LIMIT ?\
193 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
194 )
195 .bind(conversation_id)
196 .bind(av)
197 .bind(av)
198 .bind(uv)
199 .bind(uv)
200 .bind(limit)
201 .fetch_all(&self.pool)
202 .await?;
203
204 let messages = rows
205 .into_iter()
206 .map(
207 |(role_str, content, parts_json, agent_visible, user_visible)| {
208 let parts: Vec<MessagePart> = if parts_json == "[]" {
209 vec![]
210 } else {
211 serde_json::from_str(&parts_json).unwrap_or_default()
212 };
213 Message {
214 role: parse_role(&role_str),
215 content,
216 parts,
217 metadata: MessageMetadata {
218 agent_visible: agent_visible != 0,
219 user_visible: user_visible != 0,
220 compacted_at: None,
221 deferred_summary: None,
222 },
223 }
224 },
225 )
226 .collect();
227 Ok(messages)
228 }
229
230 pub async fn replace_conversation(
242 &self,
243 conversation_id: ConversationId,
244 compacted_range: std::ops::RangeInclusive<MessageId>,
245 summary_role: &str,
246 summary_content: &str,
247 ) -> Result<MessageId, MemoryError> {
248 let now = {
249 let secs = std::time::SystemTime::now()
250 .duration_since(std::time::UNIX_EPOCH)
251 .unwrap_or_default()
252 .as_secs();
253 format!("{secs}")
254 };
255 let start_id = compacted_range.start().0;
256 let end_id = compacted_range.end().0;
257
258 let mut tx = self.pool.begin().await?;
259
260 sqlx::query(
261 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
262 WHERE conversation_id = ? AND id >= ? AND id <= ?",
263 )
264 .bind(&now)
265 .bind(conversation_id)
266 .bind(start_id)
267 .bind(end_id)
268 .execute(&mut *tx)
269 .await?;
270
271 let row: (MessageId,) = sqlx::query_as(
272 "INSERT INTO messages \
273 (conversation_id, role, content, parts, agent_visible, user_visible) \
274 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
275 )
276 .bind(conversation_id)
277 .bind(summary_role)
278 .bind(summary_content)
279 .fetch_one(&mut *tx)
280 .await?;
281
282 tx.commit().await?;
283
284 Ok(row.0)
285 }
286
287 pub async fn oldest_message_ids(
293 &self,
294 conversation_id: ConversationId,
295 n: u32,
296 ) -> Result<Vec<MessageId>, MemoryError> {
297 let rows: Vec<(MessageId,)> = sqlx::query_as(
298 "SELECT id FROM messages WHERE conversation_id = ? AND deleted_at IS NULL ORDER BY id ASC LIMIT ?",
299 )
300 .bind(conversation_id)
301 .bind(n)
302 .fetch_all(&self.pool)
303 .await?;
304 Ok(rows.into_iter().map(|r| r.0).collect())
305 }
306
307 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
313 let row: Option<(ConversationId,)> =
314 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
315 .fetch_optional(&self.pool)
316 .await?;
317 Ok(row.map(|r| r.0))
318 }
319
320 pub async fn message_by_id(
326 &self,
327 message_id: MessageId,
328 ) -> Result<Option<Message>, MemoryError> {
329 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
330 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ? AND deleted_at IS NULL",
331 )
332 .bind(message_id)
333 .fetch_optional(&self.pool)
334 .await?;
335
336 Ok(row.map(
337 |(role_str, content, parts_json, agent_visible, user_visible)| {
338 let parts: Vec<MessagePart> = if parts_json == "[]" {
339 vec![]
340 } else {
341 serde_json::from_str(&parts_json).unwrap_or_default()
342 };
343 Message {
344 role: parse_role(&role_str),
345 content,
346 parts,
347 metadata: MessageMetadata {
348 agent_visible: agent_visible != 0,
349 user_visible: user_visible != 0,
350 compacted_at: None,
351 deferred_summary: None,
352 },
353 }
354 },
355 ))
356 }
357
358 pub async fn messages_by_ids(
364 &self,
365 ids: &[MessageId],
366 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
367 if ids.is_empty() {
368 return Ok(Vec::new());
369 }
370
371 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
372
373 let query = format!(
374 "SELECT id, role, content, parts FROM messages \
375 WHERE id IN ({placeholders}) AND agent_visible = 1 AND deleted_at IS NULL"
376 );
377 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
378 for &id in ids {
379 q = q.bind(id);
380 }
381
382 let rows = q.fetch_all(&self.pool).await?;
383
384 Ok(rows
385 .into_iter()
386 .map(|(id, role_str, content, parts_json)| {
387 let parts: Vec<MessagePart> = if parts_json == "[]" {
388 vec![]
389 } else {
390 serde_json::from_str(&parts_json).unwrap_or_default()
391 };
392 (
393 id,
394 Message {
395 role: parse_role(&role_str),
396 content,
397 parts,
398 metadata: MessageMetadata::default(),
399 },
400 )
401 })
402 .collect())
403 }
404
405 pub async fn unembedded_message_ids(
411 &self,
412 limit: Option<usize>,
413 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
414 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
415
416 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
417 "SELECT m.id, m.conversation_id, m.role, m.content \
418 FROM messages m \
419 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
420 WHERE em.id IS NULL AND m.deleted_at IS NULL \
421 ORDER BY m.id ASC \
422 LIMIT ?",
423 )
424 .bind(effective_limit)
425 .fetch_all(&self.pool)
426 .await?;
427
428 Ok(rows)
429 }
430
431 pub async fn count_messages(
437 &self,
438 conversation_id: ConversationId,
439 ) -> Result<i64, MemoryError> {
440 let row: (i64,) = sqlx::query_as(
441 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND deleted_at IS NULL",
442 )
443 .bind(conversation_id)
444 .fetch_one(&self.pool)
445 .await?;
446 Ok(row.0)
447 }
448
449 pub async fn count_messages_after(
455 &self,
456 conversation_id: ConversationId,
457 after_id: MessageId,
458 ) -> Result<i64, MemoryError> {
459 let row: (i64,) =
460 sqlx::query_as(
461 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL",
462 )
463 .bind(conversation_id)
464 .bind(after_id)
465 .fetch_one(&self.pool)
466 .await?;
467 Ok(row.0)
468 }
469
470 pub async fn keyword_search(
479 &self,
480 query: &str,
481 limit: usize,
482 conversation_id: Option<ConversationId>,
483 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
484 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
485 let safe_query = sanitize_fts5_query(query);
486 if safe_query.is_empty() {
487 return Ok(Vec::new());
488 }
489
490 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
491 sqlx::query_as(
492 "SELECT m.id, -rank AS score \
493 FROM messages_fts f \
494 JOIN messages m ON m.id = f.rowid \
495 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
496 ORDER BY rank \
497 LIMIT ?",
498 )
499 .bind(&safe_query)
500 .bind(cid)
501 .bind(effective_limit)
502 .fetch_all(&self.pool)
503 .await?
504 } else {
505 sqlx::query_as(
506 "SELECT m.id, -rank AS score \
507 FROM messages_fts f \
508 JOIN messages m ON m.id = f.rowid \
509 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
510 ORDER BY rank \
511 LIMIT ?",
512 )
513 .bind(&safe_query)
514 .bind(effective_limit)
515 .fetch_all(&self.pool)
516 .await?
517 };
518
519 Ok(rows)
520 }
521
522 pub async fn message_timestamps(
530 &self,
531 ids: &[MessageId],
532 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
533 if ids.is_empty() {
534 return Ok(std::collections::HashMap::new());
535 }
536
537 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
538 let query = format!(
539 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
540 FROM messages WHERE id IN ({placeholders}) AND deleted_at IS NULL"
541 );
542 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
543 for &id in ids {
544 q = q.bind(id);
545 }
546
547 let rows = q.fetch_all(&self.pool).await?;
548 Ok(rows.into_iter().collect())
549 }
550
551 pub async fn load_messages_range(
557 &self,
558 conversation_id: ConversationId,
559 after_message_id: MessageId,
560 limit: usize,
561 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
562 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
563
564 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
565 "SELECT id, role, content FROM messages \
566 WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL \
567 ORDER BY id ASC LIMIT ?",
568 )
569 .bind(conversation_id)
570 .bind(after_message_id)
571 .bind(effective_limit)
572 .fetch_all(&self.pool)
573 .await?;
574
575 Ok(rows)
576 }
577
578 pub async fn get_eviction_candidates(
586 &self,
587 ) -> Result<Vec<crate::eviction::EvictionEntry>, crate::error::MemoryError> {
588 let rows: Vec<(MessageId, String, Option<String>, i64)> = sqlx::query_as(
589 "SELECT id, created_at, last_accessed, access_count \
590 FROM messages WHERE deleted_at IS NULL",
591 )
592 .fetch_all(&self.pool)
593 .await?;
594
595 Ok(rows
596 .into_iter()
597 .map(
598 |(id, created_at, last_accessed, access_count)| crate::eviction::EvictionEntry {
599 id,
600 created_at,
601 last_accessed,
602 access_count: access_count.try_into().unwrap_or(0),
603 },
604 )
605 .collect())
606 }
607
608 pub async fn soft_delete_messages(
616 &self,
617 ids: &[MessageId],
618 ) -> Result<(), crate::error::MemoryError> {
619 if ids.is_empty() {
620 return Ok(());
621 }
622 for &id in ids {
624 sqlx::query(
625 "UPDATE messages SET deleted_at = datetime('now') WHERE id = ? AND deleted_at IS NULL",
626 )
627 .bind(id)
628 .execute(&self.pool)
629 .await?;
630 }
631 Ok(())
632 }
633
634 pub async fn get_soft_deleted_message_ids(
640 &self,
641 ) -> Result<Vec<MessageId>, crate::error::MemoryError> {
642 let rows: Vec<(MessageId,)> = sqlx::query_as(
643 "SELECT id FROM messages WHERE deleted_at IS NOT NULL AND qdrant_cleaned = 0",
644 )
645 .fetch_all(&self.pool)
646 .await?;
647 Ok(rows.into_iter().map(|(id,)| id).collect())
648 }
649
650 pub async fn mark_qdrant_cleaned(
656 &self,
657 ids: &[MessageId],
658 ) -> Result<(), crate::error::MemoryError> {
659 for &id in ids {
660 sqlx::query("UPDATE messages SET qdrant_cleaned = 1 WHERE id = ?")
661 .bind(id)
662 .execute(&self.pool)
663 .await?;
664 }
665 Ok(())
666 }
667}
668
// Unit tests for the message-store queries. Each test builds its own in-memory
// database via `test_store`, so ids restart at 1 in every test.
#[cfg(test)]
mod tests {
    use super::*;

    // Fresh in-memory store per test. The tests below rely on the messages, FTS and
    // embeddings_metadata tables existing after `SqliteStore::new` (presumably created by
    // its schema setup — confirm).
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    // load_history keeps the newest `limit` messages but returns them oldest first.
    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    // A hard DELETE of a message must also remove its embeddings_metadata row
    // (relies on a cascading foreign key in the schema).
    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    // Only the requested ids come back; the assertions also depend on results
    // arriving in ascending id order.
    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[0].1.content, "hello");
        assert_eq!(results[1].0, id2);
        assert_eq!(results[1].1.content, "hi");
    }

    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store.message_by_id(msg_id).await.unwrap();
        assert!(msg.is_some());
        let msg = msg.unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    // Tuple field `.3` is the message content.
    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    // Inserting an embeddings_metadata row for msg1 removes it from the unembedded set.
    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    // `after_id` is exclusive: counting after the last id yields 0.
    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    // Range starts strictly after `msg_id1`; tuple field `.2` is the content.
    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    // Two of three messages contain "rust"; scores are negated FTS5 ranks and are
    // expected to be positive.
    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    // Pure function test: punctuation and FTS5 operator characters become token
    // separators.
    #[test]
    fn sanitize_fts5_query_strips_special_chars() {
        assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
        assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
        assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
        assert_eq!(sanitize_fts5_query("   "), "");
        assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
    }

    // Regression check: raw text containing `-`, `,`, `=` and `.` must not cause an
    // FTS5 syntax error. Only success matters; the result is discarded.
    #[tokio::test]
    async fn keyword_search_with_special_chars_does_not_error() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        store
            .save_message(cid, "user", "skill audit info")
            .await
            .unwrap();
        store
            .keyword_search("skill-audit, confidence=0.1", 10, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn save_message_with_metadata_stores_visibility() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(!history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert_eq!(id, MessageId(1));
    }

    #[tokio::test]
    async fn load_history_filtered_by_agent_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let agent_msgs = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        assert_eq!(agent_msgs.len(), 1);
        assert_eq!(agent_msgs[0].content, "visible to agent");
    }

    #[tokio::test]
    async fn load_history_filtered_by_user_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
            .await
            .unwrap();

        let user_msgs = store
            .load_history_filtered(cid, 50, None, Some(true))
            .await
            .unwrap();
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0].content, "user sees this");
    }

    // With both filters `None`, every non-deleted message is returned.
    #[tokio::test]
    async fn load_history_filtered_no_filter_returns_all() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
            .await
            .unwrap();

        let all_msgs = store
            .load_history_filtered(cid, 50, None, None)
            .await
            .unwrap();
        assert_eq!(all_msgs.len(), 2);
    }

    // Compacting [id1, id2] hides those messages from the agent (but not the user),
    // leaves id3 untouched, and appends an agent-only summary after id3.
    #[tokio::test]
    async fn replace_conversation_marks_originals_and_inserts_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "first").await.unwrap();
        let id2 = store
            .save_message(cid, "assistant", "second")
            .await
            .unwrap();
        let id3 = store.save_message(cid, "user", "third").await.unwrap();

        let summary_id = store
            .replace_conversation(cid, id1..=id2, "system", "summary text")
            .await
            .unwrap();

        let all = store.load_history(cid, 50).await.unwrap();
        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
        assert!(!by_id1.metadata.agent_visible);
        assert!(by_id1.metadata.user_visible);

        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
        assert!(!by_id2.metadata.agent_visible);

        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
        assert!(by_id3.metadata.agent_visible);

        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
        assert!(summary.metadata.agent_visible);
        assert!(!summary.metadata.user_visible);
        assert!(summary_id > id3);
    }

    #[tokio::test]
    async fn oldest_message_ids_returns_in_order() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
        let id3 = store.save_message(cid, "user", "c").await.unwrap();

        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
        assert_eq!(ids, vec![id1, id2]);
        assert!(ids[0] < ids[1]);

        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
        assert_eq!(all_ids, vec![id1, id2, id3]);
    }

    #[tokio::test]
    async fn message_metadata_default_both_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "normal").await.unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert!(history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert!(history[0].metadata.compacted_at.is_none());
    }

    // The "[]" fast-path / JSON-decoding tests below cover each loader that decodes
    // the `parts` column.
    #[tokio::test]
    async fn load_history_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", "[]")
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(
            history[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec"
        );
    }

    // Round-trips a serialized ToolResult part through the parts column.
    #[tokio::test]
    async fn load_history_non_empty_parts_json_parsed() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
            tool_use_id: "t1".into(),
            content: "result".into(),
            is_error: false,
        }])
        .unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", &parts_json)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].parts.len(), 1);
        assert!(
            matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
        );
    }

    #[tokio::test]
    async fn message_by_id_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let msg = store.message_by_id(id).await.unwrap().unwrap();
        assert!(
            msg.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
        );
    }

    #[tokio::test]
    async fn messages_by_ids_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let results = store.messages_by_ids(&[id]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            results[0].1.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
        );
    }

    #[tokio::test]
    async fn load_history_filtered_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
            .await
            .unwrap();

        let msgs = store
            .load_history_filtered(cid, 10, Some(true), None)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 1);
        assert!(
            msgs[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
        );
    }
}