1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
/// Turn arbitrary user text into a safe FTS5 `MATCH` expression.
///
/// The input is split on every non-alphanumeric character, so FTS5 syntax
/// characters (quotes, parentheses, `*`, `^`, `:`, `+`, `-`, ...) never reach the
/// query parser. The surviving tokens are joined with single spaces, which FTS5
/// treats as an implicit `AND` of all terms.
///
/// The case-sensitive barewords `AND`, `OR` and `NOT` are FTS5 operators even
/// without punctuation. For example, `"NOT working"` or a lone `"OR"` is a
/// syntax error. Those tokens are therefore wrapped in double quotes so FTS5
/// treats them as ordinary search terms. Quoting cannot break the query because
/// a token never contains `"`.
///
/// Returns an empty string when the input has no alphanumeric characters;
/// `keyword_search` checks for that and skips the `MATCH` entirely.
fn sanitize_fts5_query(query: &str) -> String {
    let mut sanitized = String::with_capacity(query.len());
    for token in query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
    {
        if !sanitized.is_empty() {
            sanitized.push(' ');
        }
        if matches!(token, "AND" | "OR" | "NOT") {
            // Quoted strings are phrases in FTS5, never operators.
            sanitized.push('"');
            sanitized.push_str(token);
            sanitized.push('"');
        } else {
            sanitized.push_str(token);
        }
    }
    sanitized
}
22
23fn parse_role(s: &str) -> Role {
24 match s {
25 "assistant" => Role::Assistant,
26 "system" => Role::System,
27 _ => Role::User,
28 }
29}
30
31#[must_use]
32pub fn role_str(role: Role) -> &'static str {
33 match role {
34 Role::System => "system",
35 Role::User => "user",
36 Role::Assistant => "assistant",
37 }
38}
39
40impl SqliteStore {
41 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
47 let row: (ConversationId,) =
48 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
49 .fetch_one(&self.pool)
50 .await?;
51 Ok(row.0)
52 }
53
54 pub async fn save_message(
60 &self,
61 conversation_id: ConversationId,
62 role: &str,
63 content: &str,
64 ) -> Result<MessageId, MemoryError> {
65 self.save_message_with_parts(conversation_id, role, content, "[]")
66 .await
67 }
68
69 pub async fn save_message_with_parts(
75 &self,
76 conversation_id: ConversationId,
77 role: &str,
78 content: &str,
79 parts_json: &str,
80 ) -> Result<MessageId, MemoryError> {
81 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
82 .await
83 }
84
85 pub async fn save_message_with_metadata(
91 &self,
92 conversation_id: ConversationId,
93 role: &str,
94 content: &str,
95 parts_json: &str,
96 agent_visible: bool,
97 user_visible: bool,
98 ) -> Result<MessageId, MemoryError> {
99 let row: (MessageId,) = sqlx::query_as(
100 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
101 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
102 )
103 .bind(conversation_id)
104 .bind(role)
105 .bind(content)
106 .bind(parts_json)
107 .bind(i64::from(agent_visible))
108 .bind(i64::from(user_visible))
109 .fetch_one(&self.pool)
110 .await?;
111 Ok(row.0)
112 }
113
114 pub async fn load_history(
120 &self,
121 conversation_id: ConversationId,
122 limit: u32,
123 ) -> Result<Vec<Message>, MemoryError> {
124 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
125 "SELECT role, content, parts, agent_visible, user_visible FROM (\
126 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
127 WHERE conversation_id = ? \
128 ORDER BY id DESC \
129 LIMIT ?\
130 ) ORDER BY id ASC",
131 )
132 .bind(conversation_id)
133 .bind(limit)
134 .fetch_all(&self.pool)
135 .await?;
136
137 let messages = rows
138 .into_iter()
139 .map(
140 |(role_str, content, parts_json, agent_visible, user_visible)| {
141 let parts: Vec<MessagePart> = if parts_json == "[]" {
142 vec![]
143 } else {
144 serde_json::from_str(&parts_json).unwrap_or_default()
145 };
146 Message {
147 role: parse_role(&role_str),
148 content,
149 parts,
150 metadata: MessageMetadata {
151 agent_visible: agent_visible != 0,
152 user_visible: user_visible != 0,
153 compacted_at: None,
154 },
155 }
156 },
157 )
158 .collect();
159 Ok(messages)
160 }
161
162 pub async fn load_history_filtered(
170 &self,
171 conversation_id: ConversationId,
172 limit: u32,
173 agent_visible: Option<bool>,
174 user_visible: Option<bool>,
175 ) -> Result<Vec<Message>, MemoryError> {
176 let av = agent_visible.map(i64::from);
177 let uv = user_visible.map(i64::from);
178
179 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
180 "WITH recent AS (\
181 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
182 WHERE conversation_id = ? \
183 AND (? IS NULL OR agent_visible = ?) \
184 AND (? IS NULL OR user_visible = ?) \
185 ORDER BY id DESC \
186 LIMIT ?\
187 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
188 )
189 .bind(conversation_id)
190 .bind(av)
191 .bind(av)
192 .bind(uv)
193 .bind(uv)
194 .bind(limit)
195 .fetch_all(&self.pool)
196 .await?;
197
198 let messages = rows
199 .into_iter()
200 .map(
201 |(role_str, content, parts_json, agent_visible, user_visible)| {
202 let parts: Vec<MessagePart> = if parts_json == "[]" {
203 vec![]
204 } else {
205 serde_json::from_str(&parts_json).unwrap_or_default()
206 };
207 Message {
208 role: parse_role(&role_str),
209 content,
210 parts,
211 metadata: MessageMetadata {
212 agent_visible: agent_visible != 0,
213 user_visible: user_visible != 0,
214 compacted_at: None,
215 },
216 }
217 },
218 )
219 .collect();
220 Ok(messages)
221 }
222
223 pub async fn replace_conversation(
235 &self,
236 conversation_id: ConversationId,
237 compacted_range: std::ops::RangeInclusive<MessageId>,
238 summary_role: &str,
239 summary_content: &str,
240 ) -> Result<MessageId, MemoryError> {
241 let now = {
242 let secs = std::time::SystemTime::now()
243 .duration_since(std::time::UNIX_EPOCH)
244 .unwrap_or_default()
245 .as_secs();
246 format!("{secs}")
247 };
248 let start_id = compacted_range.start().0;
249 let end_id = compacted_range.end().0;
250
251 let mut tx = self.pool.begin().await?;
252
253 sqlx::query(
254 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
255 WHERE conversation_id = ? AND id >= ? AND id <= ?",
256 )
257 .bind(&now)
258 .bind(conversation_id)
259 .bind(start_id)
260 .bind(end_id)
261 .execute(&mut *tx)
262 .await?;
263
264 let row: (MessageId,) = sqlx::query_as(
265 "INSERT INTO messages \
266 (conversation_id, role, content, parts, agent_visible, user_visible) \
267 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
268 )
269 .bind(conversation_id)
270 .bind(summary_role)
271 .bind(summary_content)
272 .fetch_one(&mut *tx)
273 .await?;
274
275 tx.commit().await?;
276
277 Ok(row.0)
278 }
279
280 pub async fn oldest_message_ids(
286 &self,
287 conversation_id: ConversationId,
288 n: u32,
289 ) -> Result<Vec<MessageId>, MemoryError> {
290 let rows: Vec<(MessageId,)> = sqlx::query_as(
291 "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
292 )
293 .bind(conversation_id)
294 .bind(n)
295 .fetch_all(&self.pool)
296 .await?;
297 Ok(rows.into_iter().map(|r| r.0).collect())
298 }
299
300 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
306 let row: Option<(ConversationId,)> =
307 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
308 .fetch_optional(&self.pool)
309 .await?;
310 Ok(row.map(|r| r.0))
311 }
312
313 pub async fn message_by_id(
319 &self,
320 message_id: MessageId,
321 ) -> Result<Option<Message>, MemoryError> {
322 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
323 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
324 )
325 .bind(message_id)
326 .fetch_optional(&self.pool)
327 .await?;
328
329 Ok(row.map(
330 |(role_str, content, parts_json, agent_visible, user_visible)| {
331 let parts: Vec<MessagePart> = if parts_json == "[]" {
332 vec![]
333 } else {
334 serde_json::from_str(&parts_json).unwrap_or_default()
335 };
336 Message {
337 role: parse_role(&role_str),
338 content,
339 parts,
340 metadata: MessageMetadata {
341 agent_visible: agent_visible != 0,
342 user_visible: user_visible != 0,
343 compacted_at: None,
344 },
345 }
346 },
347 ))
348 }
349
350 pub async fn messages_by_ids(
356 &self,
357 ids: &[MessageId],
358 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
359 if ids.is_empty() {
360 return Ok(Vec::new());
361 }
362
363 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
364
365 let query = format!(
366 "SELECT id, role, content, parts FROM messages \
367 WHERE id IN ({placeholders}) AND agent_visible = 1"
368 );
369 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
370 for &id in ids {
371 q = q.bind(id);
372 }
373
374 let rows = q.fetch_all(&self.pool).await?;
375
376 Ok(rows
377 .into_iter()
378 .map(|(id, role_str, content, parts_json)| {
379 let parts: Vec<MessagePart> = if parts_json == "[]" {
380 vec![]
381 } else {
382 serde_json::from_str(&parts_json).unwrap_or_default()
383 };
384 (
385 id,
386 Message {
387 role: parse_role(&role_str),
388 content,
389 parts,
390 metadata: MessageMetadata::default(),
391 },
392 )
393 })
394 .collect())
395 }
396
397 pub async fn unembedded_message_ids(
403 &self,
404 limit: Option<usize>,
405 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
406 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
407
408 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
409 "SELECT m.id, m.conversation_id, m.role, m.content \
410 FROM messages m \
411 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
412 WHERE em.id IS NULL \
413 ORDER BY m.id ASC \
414 LIMIT ?",
415 )
416 .bind(effective_limit)
417 .fetch_all(&self.pool)
418 .await?;
419
420 Ok(rows)
421 }
422
423 pub async fn count_messages(
429 &self,
430 conversation_id: ConversationId,
431 ) -> Result<i64, MemoryError> {
432 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
433 .bind(conversation_id)
434 .fetch_one(&self.pool)
435 .await?;
436 Ok(row.0)
437 }
438
439 pub async fn count_messages_after(
445 &self,
446 conversation_id: ConversationId,
447 after_id: MessageId,
448 ) -> Result<i64, MemoryError> {
449 let row: (i64,) =
450 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
451 .bind(conversation_id)
452 .bind(after_id)
453 .fetch_one(&self.pool)
454 .await?;
455 Ok(row.0)
456 }
457
458 pub async fn keyword_search(
467 &self,
468 query: &str,
469 limit: usize,
470 conversation_id: Option<ConversationId>,
471 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
472 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
473 let safe_query = sanitize_fts5_query(query);
474 if safe_query.is_empty() {
475 return Ok(Vec::new());
476 }
477
478 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
479 sqlx::query_as(
480 "SELECT m.id, -rank AS score \
481 FROM messages_fts f \
482 JOIN messages m ON m.id = f.rowid \
483 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
484 ORDER BY rank \
485 LIMIT ?",
486 )
487 .bind(&safe_query)
488 .bind(cid)
489 .bind(effective_limit)
490 .fetch_all(&self.pool)
491 .await?
492 } else {
493 sqlx::query_as(
494 "SELECT m.id, -rank AS score \
495 FROM messages_fts f \
496 JOIN messages m ON m.id = f.rowid \
497 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
498 ORDER BY rank \
499 LIMIT ?",
500 )
501 .bind(&safe_query)
502 .bind(effective_limit)
503 .fetch_all(&self.pool)
504 .await?
505 };
506
507 Ok(rows)
508 }
509
510 pub async fn message_timestamps(
518 &self,
519 ids: &[MessageId],
520 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
521 if ids.is_empty() {
522 return Ok(std::collections::HashMap::new());
523 }
524
525 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
526 let query = format!(
527 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
528 FROM messages WHERE id IN ({placeholders})"
529 );
530 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
531 for &id in ids {
532 q = q.bind(id);
533 }
534
535 let rows = q.fetch_all(&self.pool).await?;
536 Ok(rows.into_iter().collect())
537 }
538
539 pub async fn load_messages_range(
545 &self,
546 conversation_id: ConversationId,
547 after_message_id: MessageId,
548 limit: usize,
549 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
550 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
551
552 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
553 "SELECT id, role, content FROM messages \
554 WHERE conversation_id = ? AND id > ? \
555 ORDER BY id ASC LIMIT ?",
556 )
557 .bind(conversation_id)
558 .bind(after_message_id)
559 .bind(effective_limit)
560 .fetch_all(&self.pool)
561 .await?;
562
563 Ok(rows)
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    // Opens a store backed by `:memory:`. The tests below assume it starts
    // empty, e.g. the first conversation and message IDs are 1.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    // Conversation IDs are handed out sequentially starting at 1.
    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    // Saved messages get sequential IDs and come back oldest first with role and content intact.
    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    // With a limit, load_history keeps the newest messages but still returns them oldest first.
    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    // No conversations yet -> None.
    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    // The most recently created conversation is reported as latest.
    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    // History loads only return messages of the requested conversation.
    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    // The exposed pool can run arbitrary queries.
    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    // Store construction creates the `embeddings_metadata` table.
    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    // Deleting a message must also remove its embeddings_metadata row. This presumably
    // relies on an ON DELETE CASCADE foreign key and foreign keys being enabled; see schema.
    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    // Only the requested IDs are returned, in ascending ID order.
    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[0].1.content, "hello");
        assert_eq!(results[1].0, id2);
        assert_eq!(results[1].1.content, "hi");
    }

    // An empty ID slice short-circuits to an empty result.
    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    // Unknown IDs are silently omitted rather than producing an error.
    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Fetching an existing message returns its role and content.
    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store.message_by_id(msg_id).await.unwrap();
        assert!(msg.is_some());
        let msg = msg.unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    // Fetching an unknown ID yields None.
    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    // With no embeddings recorded, every message is reported, oldest first.
    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    // A message with an embeddings_metadata row is excluded.
    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    // Some(limit) caps the number of returned rows.
    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    // The count starts at 0 and tracks saved messages.
    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    // Only messages with an ID strictly greater than the cutoff are counted.
    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    // The range excludes the anchor message and is returned in ascending ID order.
    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    // The limit caps the number of rows in the range.
    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    // Only messages containing the term match, and the negated FTS5 rank gives positive scores.
    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    // Passing a conversation ID restricts matches to that conversation.
    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // A term that appears nowhere yields an empty result, not an error.
    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    // The limit caps the number of search hits.
    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    // Non-alphanumeric characters become token separators, and whitespace-only input becomes empty.
    #[test]
    fn sanitize_fts5_query_strips_special_chars() {
        assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
        assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
        assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
        assert_eq!(sanitize_fts5_query("   "), "");
        assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
    }

    // Punctuation in the raw query (`-`, `,`, `=`, `.`) must not cause an FTS5 error.
    // Only the absence of an error is asserted, not the hits.
    #[tokio::test]
    async fn keyword_search_with_special_chars_does_not_error() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        store
            .save_message(cid, "user", "skill audit info")
            .await
            .unwrap();
        store
            .keyword_search("skill-audit, confidence=0.1", 10, None)
            .await
            .unwrap();
    }

    // Explicit visibility flags round-trip through load_history.
    #[tokio::test]
    async fn save_message_with_metadata_stores_visibility() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(!history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert_eq!(id, MessageId(1));
    }

    // Filtering on agent_visible = true drops user-only messages.
    #[tokio::test]
    async fn load_history_filtered_by_agent_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let agent_msgs = store
            .load_history_filtered(cid, 50, Some(true), None)
            .await
            .unwrap();
        assert_eq!(agent_msgs.len(), 1);
        assert_eq!(agent_msgs[0].content, "visible to agent");
    }

    // Filtering on user_visible = true drops agent-only messages.
    #[tokio::test]
    async fn load_history_filtered_by_user_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
            .await
            .unwrap();

        let user_msgs = store
            .load_history_filtered(cid, 50, None, Some(true))
            .await
            .unwrap();
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0].content, "user sees this");
    }

    // With both filters set to None, every message is returned.
    #[tokio::test]
    async fn load_history_filtered_no_filter_returns_all() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
            .await
            .unwrap();
        store
            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
            .await
            .unwrap();

        let all_msgs = store
            .load_history_filtered(cid, 50, None, None)
            .await
            .unwrap();
        assert_eq!(all_msgs.len(), 2);
    }

    // Compaction hides the range from the agent, leaves later messages alone,
    // and appends an agent-only summary.
    #[tokio::test]
    async fn replace_conversation_marks_originals_and_inserts_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "first").await.unwrap();
        let id2 = store
            .save_message(cid, "assistant", "second")
            .await
            .unwrap();
        let id3 = store.save_message(cid, "user", "third").await.unwrap();

        let summary_id = store
            .replace_conversation(cid, id1..=id2, "system", "summary text")
            .await
            .unwrap();

        let all = store.load_history(cid, 50).await.unwrap();
        // load_history applies no visibility filter, so compacted originals are still present.
        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
        // Compacted messages are hidden from the agent but keep their user visibility.
        assert!(!by_id1.metadata.agent_visible);
        assert!(by_id1.metadata.user_visible);

        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
        assert!(!by_id2.metadata.agent_visible);

        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
        assert!(by_id3.metadata.agent_visible);

        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
        // The summary is visible to the agent only and is appended after existing messages.
        assert!(summary.metadata.agent_visible);
        assert!(!summary.metadata.user_visible);
        assert!(summary_id > id3);
    }

    // IDs come back in ascending order, capped at n.
    #[tokio::test]
    async fn oldest_message_ids_returns_in_order() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "a").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
        let id3 = store.save_message(cid, "user", "c").await.unwrap();

        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
        assert_eq!(ids, vec![id1, id2]);
        assert!(ids[0] < ids[1]);

        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
        assert_eq!(all_ids, vec![id1, id2, id3]);
    }

    // save_message defaults to visible-to-both and not compacted.
    #[tokio::test]
    async fn message_metadata_default_both_visible() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "normal").await.unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert!(history[0].metadata.agent_visible);
        assert!(history[0].metadata.user_visible);
        assert!(history[0].metadata.compacted_at.is_none());
    }

    // "[]" in the parts column yields an empty parts Vec (load_history).
    #[tokio::test]
    async fn load_history_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", "[]")
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(
            history[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec"
        );
    }

    // Non-empty parts JSON is deserialized back into MessagePart values.
    #[tokio::test]
    async fn load_history_non_empty_parts_json_parsed() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
            tool_use_id: "t1".into(),
            content: "result".into(),
            is_error: false,
        }])
        .unwrap();

        store
            .save_message_with_parts(cid, "user", "hello", &parts_json)
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].parts.len(), 1);
        assert!(
            matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
        );
    }

    // "[]" in the parts column yields an empty parts Vec (message_by_id).
    #[tokio::test]
    async fn message_by_id_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let msg = store.message_by_id(id).await.unwrap().unwrap();
        assert!(
            msg.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
        );
    }

    // "[]" in the parts column yields an empty parts Vec (messages_by_ids).
    #[tokio::test]
    async fn messages_by_ids_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "user", "msg", "[]")
            .await
            .unwrap();

        let results = store.messages_by_ids(&[id]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            results[0].1.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
        );
    }

    // "[]" in the parts column yields an empty parts Vec (load_history_filtered).
    #[tokio::test]
    async fn load_history_filtered_empty_parts_json_fast_path() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
            .await
            .unwrap();

        let msgs = store
            .load_history_filtered(cid, 10, Some(true), None)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 1);
        assert!(
            msgs[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
        );
    }
}