1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
/// Turns free-form user text into a query string that SQLite FTS5 can always parse.
///
/// The input is split on every non-alphanumeric character. This keeps FTS5 syntax
/// characters (quotes, `-`, `+`, `*`, `^`, `:`, parentheses, ...) away from the
/// `MATCH` operator. The remaining tokens are joined with single spaces, which FTS5
/// treats as an implicit `AND` of all terms.
///
/// FTS5 also interprets the bare uppercase words `AND`, `OR` and `NOT` as boolean
/// operators, so inputs such as `"NOT"` or `"rust OR"` would otherwise be syntax
/// errors. FTS5 operator keywords are case-sensitive, so these exact tokens are
/// lowercased and then match as ordinary search terms.
///
/// Returns an empty string when the input has no alphanumeric characters. Callers
/// must treat that as "nothing to search for" and not pass it to `MATCH`.
pub(crate) fn sanitize_fts5_query(query: &str) -> String {
    query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| match token {
            // Case-sensitive FTS5 operators: demote them to plain terms.
            "AND" | "OR" | "NOT" => token.to_ascii_lowercase(),
            _ => token.to_owned(),
        })
        .collect::<Vec<_>>()
        .join(" ")
}
26
27fn parse_role(s: &str) -> Role {
28 match s {
29 "assistant" => Role::Assistant,
30 "system" => Role::System,
31 _ => Role::User,
32 }
33}
34
35#[must_use]
36pub fn role_str(role: Role) -> &'static str {
37 match role {
38 Role::System => "system",
39 Role::User => "user",
40 Role::Assistant => "assistant",
41 }
42}
43
44fn parse_parts_json(role_str: &str, parts_json: &str) -> Vec<MessagePart> {
49 if parts_json == "[]" {
50 return vec![];
51 }
52 match serde_json::from_str(parts_json) {
53 Ok(p) => p,
54 Err(e) => {
55 let truncated = parts_json.chars().take(120).collect::<String>();
56 tracing::warn!(
57 role = %role_str,
58 parts_json = %truncated,
59 error = %e,
60 "failed to deserialize message parts, falling back to empty"
61 );
62 vec![]
63 }
64 }
65}
66
67impl SqliteStore {
68 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
74 let row: (ConversationId,) =
75 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
76 .fetch_one(&self.pool)
77 .await?;
78 Ok(row.0)
79 }
80
81 pub async fn save_message(
87 &self,
88 conversation_id: ConversationId,
89 role: &str,
90 content: &str,
91 ) -> Result<MessageId, MemoryError> {
92 self.save_message_with_parts(conversation_id, role, content, "[]")
93 .await
94 }
95
96 pub async fn save_message_with_parts(
102 &self,
103 conversation_id: ConversationId,
104 role: &str,
105 content: &str,
106 parts_json: &str,
107 ) -> Result<MessageId, MemoryError> {
108 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
109 .await
110 }
111
112 pub async fn save_message_with_metadata(
118 &self,
119 conversation_id: ConversationId,
120 role: &str,
121 content: &str,
122 parts_json: &str,
123 agent_visible: bool,
124 user_visible: bool,
125 ) -> Result<MessageId, MemoryError> {
126 let row: (MessageId,) = sqlx::query_as(
127 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
128 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
129 )
130 .bind(conversation_id)
131 .bind(role)
132 .bind(content)
133 .bind(parts_json)
134 .bind(i64::from(agent_visible))
135 .bind(i64::from(user_visible))
136 .fetch_one(&self.pool)
137 .await?;
138 Ok(row.0)
139 }
140
141 pub async fn load_history(
147 &self,
148 conversation_id: ConversationId,
149 limit: u32,
150 ) -> Result<Vec<Message>, MemoryError> {
151 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
152 "SELECT role, content, parts, agent_visible, user_visible FROM (\
153 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
154 WHERE conversation_id = ? AND deleted_at IS NULL \
155 ORDER BY id DESC \
156 LIMIT ?\
157 ) ORDER BY id ASC",
158 )
159 .bind(conversation_id)
160 .bind(limit)
161 .fetch_all(&self.pool)
162 .await?;
163
164 let messages = rows
165 .into_iter()
166 .map(
167 |(role_str, content, parts_json, agent_visible, user_visible)| {
168 let parts = parse_parts_json(&role_str, &parts_json);
169 Message {
170 role: parse_role(&role_str),
171 content,
172 parts,
173 metadata: MessageMetadata {
174 agent_visible: agent_visible != 0,
175 user_visible: user_visible != 0,
176 compacted_at: None,
177 deferred_summary: None,
178 },
179 }
180 },
181 )
182 .collect();
183 Ok(messages)
184 }
185
186 pub async fn load_history_filtered(
194 &self,
195 conversation_id: ConversationId,
196 limit: u32,
197 agent_visible: Option<bool>,
198 user_visible: Option<bool>,
199 ) -> Result<Vec<Message>, MemoryError> {
200 let av = agent_visible.map(i64::from);
201 let uv = user_visible.map(i64::from);
202
203 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
204 "WITH recent AS (\
205 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
206 WHERE conversation_id = ? \
207 AND deleted_at IS NULL \
208 AND (? IS NULL OR agent_visible = ?) \
209 AND (? IS NULL OR user_visible = ?) \
210 ORDER BY id DESC \
211 LIMIT ?\
212 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
213 )
214 .bind(conversation_id)
215 .bind(av)
216 .bind(av)
217 .bind(uv)
218 .bind(uv)
219 .bind(limit)
220 .fetch_all(&self.pool)
221 .await?;
222
223 let messages = rows
224 .into_iter()
225 .map(
226 |(role_str, content, parts_json, agent_visible, user_visible)| {
227 let parts = parse_parts_json(&role_str, &parts_json);
228 Message {
229 role: parse_role(&role_str),
230 content,
231 parts,
232 metadata: MessageMetadata {
233 agent_visible: agent_visible != 0,
234 user_visible: user_visible != 0,
235 compacted_at: None,
236 deferred_summary: None,
237 },
238 }
239 },
240 )
241 .collect();
242 Ok(messages)
243 }
244
245 pub async fn replace_conversation(
257 &self,
258 conversation_id: ConversationId,
259 compacted_range: std::ops::RangeInclusive<MessageId>,
260 summary_role: &str,
261 summary_content: &str,
262 ) -> Result<MessageId, MemoryError> {
263 let now = {
264 let secs = std::time::SystemTime::now()
265 .duration_since(std::time::UNIX_EPOCH)
266 .unwrap_or_default()
267 .as_secs();
268 format!("{secs}")
269 };
270 let start_id = compacted_range.start().0;
271 let end_id = compacted_range.end().0;
272
273 let mut tx = self.pool.begin().await?;
274
275 sqlx::query(
276 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
277 WHERE conversation_id = ? AND id >= ? AND id <= ?",
278 )
279 .bind(&now)
280 .bind(conversation_id)
281 .bind(start_id)
282 .bind(end_id)
283 .execute(&mut *tx)
284 .await?;
285
286 let row: (MessageId,) = sqlx::query_as(
287 "INSERT INTO messages \
288 (conversation_id, role, content, parts, agent_visible, user_visible) \
289 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
290 )
291 .bind(conversation_id)
292 .bind(summary_role)
293 .bind(summary_content)
294 .fetch_one(&mut *tx)
295 .await?;
296
297 tx.commit().await?;
298
299 Ok(row.0)
300 }
301
302 pub async fn oldest_message_ids(
308 &self,
309 conversation_id: ConversationId,
310 n: u32,
311 ) -> Result<Vec<MessageId>, MemoryError> {
312 let rows: Vec<(MessageId,)> = sqlx::query_as(
313 "SELECT id FROM messages WHERE conversation_id = ? AND deleted_at IS NULL ORDER BY id ASC LIMIT ?",
314 )
315 .bind(conversation_id)
316 .bind(n)
317 .fetch_all(&self.pool)
318 .await?;
319 Ok(rows.into_iter().map(|r| r.0).collect())
320 }
321
322 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
328 let row: Option<(ConversationId,)> =
329 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
330 .fetch_optional(&self.pool)
331 .await?;
332 Ok(row.map(|r| r.0))
333 }
334
335 pub async fn message_by_id(
341 &self,
342 message_id: MessageId,
343 ) -> Result<Option<Message>, MemoryError> {
344 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
345 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ? AND deleted_at IS NULL",
346 )
347 .bind(message_id)
348 .fetch_optional(&self.pool)
349 .await?;
350
351 Ok(row.map(
352 |(role_str, content, parts_json, agent_visible, user_visible)| {
353 let parts = parse_parts_json(&role_str, &parts_json);
354 Message {
355 role: parse_role(&role_str),
356 content,
357 parts,
358 metadata: MessageMetadata {
359 agent_visible: agent_visible != 0,
360 user_visible: user_visible != 0,
361 compacted_at: None,
362 deferred_summary: None,
363 },
364 }
365 },
366 ))
367 }
368
369 pub async fn messages_by_ids(
375 &self,
376 ids: &[MessageId],
377 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
378 if ids.is_empty() {
379 return Ok(Vec::new());
380 }
381
382 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
383
384 let query = format!(
385 "SELECT id, role, content, parts FROM messages \
386 WHERE id IN ({placeholders}) AND agent_visible = 1 AND deleted_at IS NULL"
387 );
388 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
389 for &id in ids {
390 q = q.bind(id);
391 }
392
393 let rows = q.fetch_all(&self.pool).await?;
394
395 Ok(rows
396 .into_iter()
397 .map(|(id, role_str, content, parts_json)| {
398 let parts = parse_parts_json(&role_str, &parts_json);
399 (
400 id,
401 Message {
402 role: parse_role(&role_str),
403 content,
404 parts,
405 metadata: MessageMetadata::default(),
406 },
407 )
408 })
409 .collect())
410 }
411
412 pub async fn unembedded_message_ids(
418 &self,
419 limit: Option<usize>,
420 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
421 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
422
423 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
424 "SELECT m.id, m.conversation_id, m.role, m.content \
425 FROM messages m \
426 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
427 WHERE em.id IS NULL AND m.deleted_at IS NULL \
428 ORDER BY m.id ASC \
429 LIMIT ?",
430 )
431 .bind(effective_limit)
432 .fetch_all(&self.pool)
433 .await?;
434
435 Ok(rows)
436 }
437
438 pub async fn count_messages(
444 &self,
445 conversation_id: ConversationId,
446 ) -> Result<i64, MemoryError> {
447 let row: (i64,) = sqlx::query_as(
448 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND deleted_at IS NULL",
449 )
450 .bind(conversation_id)
451 .fetch_one(&self.pool)
452 .await?;
453 Ok(row.0)
454 }
455
456 pub async fn count_messages_after(
462 &self,
463 conversation_id: ConversationId,
464 after_id: MessageId,
465 ) -> Result<i64, MemoryError> {
466 let row: (i64,) =
467 sqlx::query_as(
468 "SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL",
469 )
470 .bind(conversation_id)
471 .bind(after_id)
472 .fetch_one(&self.pool)
473 .await?;
474 Ok(row.0)
475 }
476
477 pub async fn keyword_search(
486 &self,
487 query: &str,
488 limit: usize,
489 conversation_id: Option<ConversationId>,
490 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
491 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
492 let safe_query = sanitize_fts5_query(query);
493 if safe_query.is_empty() {
494 return Ok(Vec::new());
495 }
496
497 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
498 sqlx::query_as(
499 "SELECT m.id, -rank AS score \
500 FROM messages_fts f \
501 JOIN messages m ON m.id = f.rowid \
502 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
503 ORDER BY rank \
504 LIMIT ?",
505 )
506 .bind(&safe_query)
507 .bind(cid)
508 .bind(effective_limit)
509 .fetch_all(&self.pool)
510 .await?
511 } else {
512 sqlx::query_as(
513 "SELECT m.id, -rank AS score \
514 FROM messages_fts f \
515 JOIN messages m ON m.id = f.rowid \
516 WHERE messages_fts MATCH ? AND m.agent_visible = 1 AND m.deleted_at IS NULL \
517 ORDER BY rank \
518 LIMIT ?",
519 )
520 .bind(&safe_query)
521 .bind(effective_limit)
522 .fetch_all(&self.pool)
523 .await?
524 };
525
526 Ok(rows)
527 }
528
529 pub async fn message_timestamps(
537 &self,
538 ids: &[MessageId],
539 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
540 if ids.is_empty() {
541 return Ok(std::collections::HashMap::new());
542 }
543
544 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
545 let query = format!(
546 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
547 FROM messages WHERE id IN ({placeholders}) AND deleted_at IS NULL"
548 );
549 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
550 for &id in ids {
551 q = q.bind(id);
552 }
553
554 let rows = q.fetch_all(&self.pool).await?;
555 Ok(rows.into_iter().collect())
556 }
557
558 pub async fn load_messages_range(
564 &self,
565 conversation_id: ConversationId,
566 after_message_id: MessageId,
567 limit: usize,
568 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
569 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
570
571 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
572 "SELECT id, role, content FROM messages \
573 WHERE conversation_id = ? AND id > ? AND deleted_at IS NULL \
574 ORDER BY id ASC LIMIT ?",
575 )
576 .bind(conversation_id)
577 .bind(after_message_id)
578 .bind(effective_limit)
579 .fetch_all(&self.pool)
580 .await?;
581
582 Ok(rows)
583 }
584
585 pub async fn get_eviction_candidates(
593 &self,
594 ) -> Result<Vec<crate::eviction::EvictionEntry>, crate::error::MemoryError> {
595 let rows: Vec<(MessageId, String, Option<String>, i64)> = sqlx::query_as(
596 "SELECT id, created_at, last_accessed, access_count \
597 FROM messages WHERE deleted_at IS NULL",
598 )
599 .fetch_all(&self.pool)
600 .await?;
601
602 Ok(rows
603 .into_iter()
604 .map(
605 |(id, created_at, last_accessed, access_count)| crate::eviction::EvictionEntry {
606 id,
607 created_at,
608 last_accessed,
609 access_count: access_count.try_into().unwrap_or(0),
610 },
611 )
612 .collect())
613 }
614
615 pub async fn soft_delete_messages(
623 &self,
624 ids: &[MessageId],
625 ) -> Result<(), crate::error::MemoryError> {
626 if ids.is_empty() {
627 return Ok(());
628 }
629 for &id in ids {
631 sqlx::query(
632 "UPDATE messages SET deleted_at = datetime('now') WHERE id = ? AND deleted_at IS NULL",
633 )
634 .bind(id)
635 .execute(&self.pool)
636 .await?;
637 }
638 Ok(())
639 }
640
641 pub async fn get_soft_deleted_message_ids(
647 &self,
648 ) -> Result<Vec<MessageId>, crate::error::MemoryError> {
649 let rows: Vec<(MessageId,)> = sqlx::query_as(
650 "SELECT id FROM messages WHERE deleted_at IS NOT NULL AND qdrant_cleaned = 0",
651 )
652 .fetch_all(&self.pool)
653 .await?;
654 Ok(rows.into_iter().map(|(id,)| id).collect())
655 }
656
657 pub async fn mark_qdrant_cleaned(
663 &self,
664 ids: &[MessageId],
665 ) -> Result<(), crate::error::MemoryError> {
666 for &id in ids {
667 sqlx::query("UPDATE messages SET qdrant_cleaned = 1 WHERE id = ?")
668 .bind(id)
669 .execute(&self.pool)
670 .await?;
671 }
672 Ok(())
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store with all migrations applied.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[0].1.content, "hello");
        assert_eq!(results[1].0, id2);
        assert_eq!(results[1].1.content, "hi");
    }

    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_skips_agent_invisible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let hidden = store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();
        let visible = store.save_message(cid, "user", "for agent").await.unwrap();

        let results = store.messages_by_ids(&[hidden, visible]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, visible);
        assert_eq!(results[0].1.content, "for agent");
    }

    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store.message_by_id(msg_id).await.unwrap();
        assert!(msg.is_some());
        let msg = msg.unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn keyword_search_punctuation_only_query_returns_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        // Nothing alphanumeric survives sanitization, so no MATCH is issued.
        let results = store.keyword_search("?!-*", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_excludes_hidden_and_deleted() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let visible = store.save_message(cid, "user", "rust visible").await.unwrap();
        store
            .save_message_with_metadata(cid, "user", "rust hidden", "[]", false, true)
            .await
            .unwrap();
        let deleted = store.save_message(cid, "user", "rust deleted").await.unwrap();
        store.soft_delete_messages(&[deleted]).await.unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, visible);
    }

    #[test]
    fn sanitize_fts5_query_strips_special_chars() {
        assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
        assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
        assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
        assert_eq!(sanitize_fts5_query("  "), "");
        assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
    }

    #[tokio::test]
    async fn keyword_search_with_special_chars_does_not_error() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        store
            .save_message(cid, "user", "skill audit info")
            .await
            .unwrap();
        // `-`, `,`, `=` and `.` are FTS5 syntax; the query must be sanitized rather
        // than rejected with a syntax error.
        store
            .keyword_search("skill-audit, confidence=0.1", 10, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn save_message_with_metadata_stores_visibility() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(!history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert_eq!(id, MessageId(1));
    }

    #[tokio::test]
    async fn load_history_filtered_by_agent_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let agent_msgs = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        assert_eq!(agent_msgs.len(), 1);
        assert_eq!(agent_msgs[0].content, "visible to agent");
    }

    #[tokio::test]
    async fn load_history_filtered_by_user_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
            .await
            .unwrap();

        let user_msgs = store
            .load_history_filtered(cid, 50, None, Some(true))
            .await
            .unwrap();
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0].content, "user sees this");
    }

    #[tokio::test]
    async fn load_history_filtered_no_filter_returns_all() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
            .await
            .unwrap();

        let all_msgs = store
            .load_history_filtered(cid, 50, None, None)
            .await
            .unwrap();
        assert_eq!(all_msgs.len(), 2);
    }

    #[tokio::test]
    async fn replace_conversation_marks_originals_and_inserts_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "first").await.unwrap();
        let id2 = store
            .save_message(cid, "assistant", "second")
            .await
            .unwrap();
        let id3 = store.save_message(cid, "user", "third").await.unwrap();

        let summary_id = store
            .replace_conversation(cid, id1..=id2, "system", "summary text")
            .await
            .unwrap();

        let all = store.load_history(cid, 50).await.unwrap();

        // Compacted originals are hidden from the agent but still shown to the user.
        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
        assert!(!by_id1.metadata.agent_visible);
        assert!(by_id1.metadata.user_visible);

        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
        assert!(!by_id2.metadata.agent_visible);

        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
        assert!(by_id3.metadata.agent_visible);

        // The summary is agent-only and is appended after the existing messages.
        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
        assert!(summary.metadata.agent_visible);
        assert!(!summary.metadata.user_visible);
        assert!(summary_id > id3);
    }

    #[tokio::test]
    async fn replace_conversation_agent_and_user_views() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "first").await.unwrap();
        let id2 = store
            .save_message(cid, "assistant", "second")
            .await
            .unwrap();
        store.save_message(cid, "user", "third").await.unwrap();

        store
            .replace_conversation(cid, id1..=id2, "system", "summary text")
            .await
            .unwrap();

        let agent_view = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        let agent_contents: Vec<&str> = agent_view.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(agent_contents, vec!["third", "summary text"]);

        let user_view = store
            .load_history_filtered(cid, 50, None, Some(true))
            .await
            .unwrap();
        let user_contents: Vec<&str> = user_view.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(user_contents, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn oldest_message_ids_returns_in_order() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
        let id3 = store.save_message(cid, "user", "c").await.unwrap();

        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
        assert_eq!(ids, vec![id1, id2]);
        assert!(ids[0] < ids[1]);

        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
        assert_eq!(all_ids, vec![id1, id2, id3]);
    }

    #[tokio::test]
    async fn message_metadata_default_both_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "normal").await.unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert!(history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert!(history[0].metadata.compacted_at.is_none());
    }

    #[tokio::test]
    async fn load_history_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", "[]")
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(
            history[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec"
        );
    }

    #[tokio::test]
    async fn load_history_non_empty_parts_json_parsed() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
            tool_use_id: "t1".into(),
            content: "result".into(),
            is_error: false,
        }])
        .unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", &parts_json)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].parts.len(), 1);
        assert!(
            matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
        );
    }

    #[tokio::test]
    async fn message_by_id_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let msg = store.message_by_id(id).await.unwrap().unwrap();
        assert!(
            msg.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
        );
    }

    #[tokio::test]
    async fn messages_by_ids_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let results = store.messages_by_ids(&[id]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            results[0].1.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
        );
    }

    #[tokio::test]
    async fn load_history_filtered_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
            .await
            .unwrap();

        let msgs = store
            .load_history_filtered(cid, 10, Some(true), None)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 1);
        assert!(
            msgs[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
        );
    }

    #[tokio::test]
    async fn soft_delete_hides_messages_from_reads() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let keep = store.save_message(cid, "user", "keep me").await.unwrap();
        let gone = store
            .save_message(cid, "assistant", "delete me")
            .await
            .unwrap();

        store.soft_delete_messages(&[gone]).await.unwrap();

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "keep me");
        assert_eq!(store.count_messages(cid).await.unwrap(), 1);
        assert!(store.message_by_id(gone).await.unwrap().is_none());
        assert!(store.message_by_id(keep).await.unwrap().is_some());
        assert_eq!(store.oldest_message_ids(cid, 10).await.unwrap(), vec![keep]);
    }

    #[tokio::test]
    async fn soft_delete_empty_input_is_noop() {
        let store = test_store().await;
        store.soft_delete_messages(&[]).await.unwrap();
        store.mark_qdrant_cleaned(&[]).await.unwrap();
    }

    #[tokio::test]
    async fn soft_deleted_ids_tracked_until_qdrant_cleaned() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let _id2 = store.save_message(cid, "user", "b").await.unwrap();

        assert!(store.get_soft_deleted_message_ids().await.unwrap().is_empty());

        store.soft_delete_messages(&[id1]).await.unwrap();
        assert_eq!(
            store.get_soft_deleted_message_ids().await.unwrap(),
            vec![id1]
        );

        store.mark_qdrant_cleaned(&[id1]).await.unwrap();
        assert!(store.get_soft_deleted_message_ids().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn message_timestamps_empty_input() {
        let store = test_store().await;
        assert!(store.message_timestamps(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn message_timestamps_only_live_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "user", "b").await.unwrap();
        store.soft_delete_messages(&[id2]).await.unwrap();

        let ts = store
            .message_timestamps(&[id1, id2, MessageId(999)])
            .await
            .unwrap();
        assert_eq!(ts.len(), 1);
        assert!(ts[&id1] > 0);
    }

    #[tokio::test]
    async fn eviction_candidates_exclude_soft_deleted() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let live = store.save_message(cid, "user", "live").await.unwrap();
        let gone = store.save_message(cid, "user", "gone").await.unwrap();
        store.soft_delete_messages(&[gone]).await.unwrap();

        let ids: Vec<MessageId> = store
            .get_eviction_candidates()
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.id)
            .collect();
        assert_eq!(ids, vec![live]);
    }
}