1use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10fn parse_role(s: &str) -> Role {
11 match s {
12 "assistant" => Role::Assistant,
13 "system" => Role::System,
14 _ => Role::User,
15 }
16}
17
18#[must_use]
19pub fn role_str(role: Role) -> &'static str {
20 match role {
21 Role::System => "system",
22 Role::User => "user",
23 Role::Assistant => "assistant",
24 }
25}
26
27impl SqliteStore {
28 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
34 let row: (ConversationId,) =
35 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
36 .fetch_one(&self.pool)
37 .await?;
38 Ok(row.0)
39 }
40
41 pub async fn save_message(
47 &self,
48 conversation_id: ConversationId,
49 role: &str,
50 content: &str,
51 ) -> Result<MessageId, MemoryError> {
52 self.save_message_with_parts(conversation_id, role, content, "[]")
53 .await
54 }
55
56 pub async fn save_message_with_parts(
62 &self,
63 conversation_id: ConversationId,
64 role: &str,
65 content: &str,
66 parts_json: &str,
67 ) -> Result<MessageId, MemoryError> {
68 self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
69 .await
70 }
71
72 pub async fn save_message_with_metadata(
78 &self,
79 conversation_id: ConversationId,
80 role: &str,
81 content: &str,
82 parts_json: &str,
83 agent_visible: bool,
84 user_visible: bool,
85 ) -> Result<MessageId, MemoryError> {
86 let row: (MessageId,) = sqlx::query_as(
87 "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
88 VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
89 )
90 .bind(conversation_id)
91 .bind(role)
92 .bind(content)
93 .bind(parts_json)
94 .bind(i64::from(agent_visible))
95 .bind(i64::from(user_visible))
96 .fetch_one(&self.pool)
97 .await?;
98 Ok(row.0)
99 }
100
101 pub async fn load_history(
107 &self,
108 conversation_id: ConversationId,
109 limit: u32,
110 ) -> Result<Vec<Message>, MemoryError> {
111 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
112 "SELECT role, content, parts, agent_visible, user_visible FROM (\
113 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
114 WHERE conversation_id = ? \
115 ORDER BY id DESC \
116 LIMIT ?\
117 ) ORDER BY id ASC",
118 )
119 .bind(conversation_id)
120 .bind(limit)
121 .fetch_all(&self.pool)
122 .await?;
123
124 let messages = rows
125 .into_iter()
126 .map(
127 |(role_str, content, parts_json, agent_visible, user_visible)| {
128 let parts: Vec<MessagePart> = if parts_json == "[]" {
129 vec![]
130 } else {
131 serde_json::from_str(&parts_json).unwrap_or_default()
132 };
133 Message {
134 role: parse_role(&role_str),
135 content,
136 parts,
137 metadata: MessageMetadata {
138 agent_visible: agent_visible != 0,
139 user_visible: user_visible != 0,
140 compacted_at: None,
141 },
142 }
143 },
144 )
145 .collect();
146 Ok(messages)
147 }
148
149 pub async fn load_history_filtered(
157 &self,
158 conversation_id: ConversationId,
159 limit: u32,
160 agent_visible: Option<bool>,
161 user_visible: Option<bool>,
162 ) -> Result<Vec<Message>, MemoryError> {
163 let av = agent_visible.map(i64::from);
164 let uv = user_visible.map(i64::from);
165
166 let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
167 "WITH recent AS (\
168 SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
169 WHERE conversation_id = ? \
170 AND (? IS NULL OR agent_visible = ?) \
171 AND (? IS NULL OR user_visible = ?) \
172 ORDER BY id DESC \
173 LIMIT ?\
174 ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
175 )
176 .bind(conversation_id)
177 .bind(av)
178 .bind(av)
179 .bind(uv)
180 .bind(uv)
181 .bind(limit)
182 .fetch_all(&self.pool)
183 .await?;
184
185 let messages = rows
186 .into_iter()
187 .map(
188 |(role_str, content, parts_json, agent_visible, user_visible)| {
189 let parts: Vec<MessagePart> = if parts_json == "[]" {
190 vec![]
191 } else {
192 serde_json::from_str(&parts_json).unwrap_or_default()
193 };
194 Message {
195 role: parse_role(&role_str),
196 content,
197 parts,
198 metadata: MessageMetadata {
199 agent_visible: agent_visible != 0,
200 user_visible: user_visible != 0,
201 compacted_at: None,
202 },
203 }
204 },
205 )
206 .collect();
207 Ok(messages)
208 }
209
210 pub async fn replace_conversation(
222 &self,
223 conversation_id: ConversationId,
224 compacted_range: std::ops::RangeInclusive<MessageId>,
225 summary_role: &str,
226 summary_content: &str,
227 ) -> Result<MessageId, MemoryError> {
228 let now = {
229 let secs = std::time::SystemTime::now()
230 .duration_since(std::time::UNIX_EPOCH)
231 .unwrap_or_default()
232 .as_secs();
233 format!("{secs}")
234 };
235 let start_id = compacted_range.start().0;
236 let end_id = compacted_range.end().0;
237
238 let mut tx = self.pool.begin().await?;
239
240 sqlx::query(
241 "UPDATE messages SET agent_visible = 0, compacted_at = ? \
242 WHERE conversation_id = ? AND id >= ? AND id <= ?",
243 )
244 .bind(&now)
245 .bind(conversation_id)
246 .bind(start_id)
247 .bind(end_id)
248 .execute(&mut *tx)
249 .await?;
250
251 let row: (MessageId,) = sqlx::query_as(
252 "INSERT INTO messages \
253 (conversation_id, role, content, parts, agent_visible, user_visible) \
254 VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
255 )
256 .bind(conversation_id)
257 .bind(summary_role)
258 .bind(summary_content)
259 .fetch_one(&mut *tx)
260 .await?;
261
262 tx.commit().await?;
263
264 Ok(row.0)
265 }
266
267 pub async fn oldest_message_ids(
273 &self,
274 conversation_id: ConversationId,
275 n: u32,
276 ) -> Result<Vec<MessageId>, MemoryError> {
277 let rows: Vec<(MessageId,)> = sqlx::query_as(
278 "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
279 )
280 .bind(conversation_id)
281 .bind(n)
282 .fetch_all(&self.pool)
283 .await?;
284 Ok(rows.into_iter().map(|r| r.0).collect())
285 }
286
287 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
293 let row: Option<(ConversationId,)> =
294 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
295 .fetch_optional(&self.pool)
296 .await?;
297 Ok(row.map(|r| r.0))
298 }
299
300 pub async fn message_by_id(
306 &self,
307 message_id: MessageId,
308 ) -> Result<Option<Message>, MemoryError> {
309 let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
310 "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
311 )
312 .bind(message_id)
313 .fetch_optional(&self.pool)
314 .await?;
315
316 Ok(row.map(
317 |(role_str, content, parts_json, agent_visible, user_visible)| {
318 let parts: Vec<MessagePart> = if parts_json == "[]" {
319 vec![]
320 } else {
321 serde_json::from_str(&parts_json).unwrap_or_default()
322 };
323 Message {
324 role: parse_role(&role_str),
325 content,
326 parts,
327 metadata: MessageMetadata {
328 agent_visible: agent_visible != 0,
329 user_visible: user_visible != 0,
330 compacted_at: None,
331 },
332 }
333 },
334 ))
335 }
336
337 pub async fn messages_by_ids(
343 &self,
344 ids: &[MessageId],
345 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
346 if ids.is_empty() {
347 return Ok(Vec::new());
348 }
349
350 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
351
352 let query = format!(
353 "SELECT id, role, content, parts FROM messages \
354 WHERE id IN ({placeholders}) AND agent_visible = 1"
355 );
356 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
357 for &id in ids {
358 q = q.bind(id);
359 }
360
361 let rows = q.fetch_all(&self.pool).await?;
362
363 Ok(rows
364 .into_iter()
365 .map(|(id, role_str, content, parts_json)| {
366 let parts: Vec<MessagePart> = if parts_json == "[]" {
367 vec![]
368 } else {
369 serde_json::from_str(&parts_json).unwrap_or_default()
370 };
371 (
372 id,
373 Message {
374 role: parse_role(&role_str),
375 content,
376 parts,
377 metadata: MessageMetadata::default(),
378 },
379 )
380 })
381 .collect())
382 }
383
384 pub async fn unembedded_message_ids(
390 &self,
391 limit: Option<usize>,
392 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
393 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
394
395 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
396 "SELECT m.id, m.conversation_id, m.role, m.content \
397 FROM messages m \
398 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
399 WHERE em.id IS NULL \
400 ORDER BY m.id ASC \
401 LIMIT ?",
402 )
403 .bind(effective_limit)
404 .fetch_all(&self.pool)
405 .await?;
406
407 Ok(rows)
408 }
409
410 pub async fn count_messages(
416 &self,
417 conversation_id: ConversationId,
418 ) -> Result<i64, MemoryError> {
419 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
420 .bind(conversation_id)
421 .fetch_one(&self.pool)
422 .await?;
423 Ok(row.0)
424 }
425
426 pub async fn count_messages_after(
432 &self,
433 conversation_id: ConversationId,
434 after_id: MessageId,
435 ) -> Result<i64, MemoryError> {
436 let row: (i64,) =
437 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
438 .bind(conversation_id)
439 .bind(after_id)
440 .fetch_one(&self.pool)
441 .await?;
442 Ok(row.0)
443 }
444
445 pub async fn keyword_search(
454 &self,
455 query: &str,
456 limit: usize,
457 conversation_id: Option<ConversationId>,
458 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
459 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
460
461 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
462 sqlx::query_as(
463 "SELECT m.id, -rank AS score \
464 FROM messages_fts f \
465 JOIN messages m ON m.id = f.rowid \
466 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
467 ORDER BY rank \
468 LIMIT ?",
469 )
470 .bind(query)
471 .bind(cid)
472 .bind(effective_limit)
473 .fetch_all(&self.pool)
474 .await?
475 } else {
476 sqlx::query_as(
477 "SELECT m.id, -rank AS score \
478 FROM messages_fts f \
479 JOIN messages m ON m.id = f.rowid \
480 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
481 ORDER BY rank \
482 LIMIT ?",
483 )
484 .bind(query)
485 .bind(effective_limit)
486 .fetch_all(&self.pool)
487 .await?
488 };
489
490 Ok(rows)
491 }
492
493 pub async fn message_timestamps(
501 &self,
502 ids: &[MessageId],
503 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
504 if ids.is_empty() {
505 return Ok(std::collections::HashMap::new());
506 }
507
508 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
509 let query = format!(
510 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
511 FROM messages WHERE id IN ({placeholders})"
512 );
513 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
514 for &id in ids {
515 q = q.bind(id);
516 }
517
518 let rows = q.fetch_all(&self.pool).await?;
519 Ok(rows.into_iter().collect())
520 }
521
522 pub async fn load_messages_range(
528 &self,
529 conversation_id: ConversationId,
530 after_message_id: MessageId,
531 limit: usize,
532 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
533 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
534
535 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
536 "SELECT id, role, content FROM messages \
537 WHERE conversation_id = ? AND id > ? \
538 ORDER BY id ASC LIMIT ?",
539 )
540 .bind(conversation_id)
541 .bind(after_message_id)
542 .bind(effective_limit)
543 .fetch_all(&self.pool)
544 .await?;
545
546 Ok(rows)
547 }
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory store so every test starts from empty tables.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Conversation IDs start at 1 and increase by one per insert.
    #[tokio::test]
    async fn create_conversation_returns_id() {
        let db = test_store().await;
        let first = db.create_conversation().await.unwrap();
        let second = db.create_conversation().await.unwrap();
        assert_eq!(first, ConversationId(1));
        assert_eq!(second, ConversationId(2));
    }

    /// Saved messages come back from `load_history` in insertion order.
    #[tokio::test]
    async fn save_and_load_messages() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let question = db.save_message(conv, "user", "hello").await.unwrap();
        let answer = db
            .save_message(conv, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(question, MessageId(1));
        assert_eq!(answer, MessageId(2));

        let loaded = db.load_history(conv, 50).await.unwrap();
        assert_eq!(loaded.len(), 2);

        let (first, second) = (&loaded[0], &loaded[1]);
        assert_eq!(first.role, Role::User);
        assert_eq!(first.content, "hello");
        assert_eq!(second.role, Role::Assistant);
        assert_eq!(second.content, "hi there");
    }

    /// With a limit of 3, only the three newest of ten messages are returned,
    /// oldest of those first.
    #[tokio::test]
    async fn load_history_respects_limit() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        for n in 0..10 {
            let text = format!("msg {n}");
            db.save_message(conv, "user", &text).await.unwrap();
        }

        let loaded = db.load_history(conv, 3).await.unwrap();
        let contents: Vec<&str> = loaded.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, ["msg 7", "msg 8", "msg 9"]);
    }

    /// An empty database has no latest conversation.
    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let db = test_store().await;
        let latest = db.latest_conversation_id().await.unwrap();
        assert!(latest.is_none());
    }

    /// The most recently created conversation is reported as latest.
    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let db = test_store().await;
        let _older = db.create_conversation().await.unwrap();
        let newer = db.create_conversation().await.unwrap();

        let latest = db.latest_conversation_id().await.unwrap();
        assert_eq!(latest, Some(newer));
    }

    /// History of one conversation never contains another's messages.
    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let db = test_store().await;
        let conv_a = db.create_conversation().await.unwrap();
        let conv_b = db.create_conversation().await.unwrap();

        db.save_message(conv_a, "user", "conv1").await.unwrap();
        db.save_message(conv_b, "user", "conv2").await.unwrap();

        let history_a = db.load_history(conv_a, 50).await.unwrap();
        let history_b = db.load_history(conv_b, 50).await.unwrap();

        assert_eq!(history_a.len(), 1);
        assert_eq!(history_a[0].content, "conv1");
        assert_eq!(history_b.len(), 1);
        assert_eq!(history_b[0].content, "conv2");
    }

    /// The pool handed out by `pool()` can execute a query.
    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let db = test_store().await;
        let (one,): (i64,) = sqlx::query_as("SELECT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(one, 1);
    }

    /// A new store contains the `embeddings_metadata` table.
    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let db = test_store().await;
        let (table_count,): (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(db.pool())
        .await
        .unwrap();
        assert_eq!(table_count, 1);
    }

    /// Deleting a message removes its `embeddings_metadata` row as well.
    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let db = test_store().await;
        let pool = db.pool();

        let conv = db.create_conversation().await.unwrap();
        let message = db.save_message(conv, "user", "test").await.unwrap();

        let point = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(message)
        .bind(&point)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let (rows_before,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(message)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(rows_before, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(message)
            .execute(pool)
            .await
            .unwrap();

        let (rows_after,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(message)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(rows_after, 0);
    }

    /// Only the requested IDs are returned by a batch fetch.
    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();
        let first = db.save_message(conv, "user", "hello").await.unwrap();
        let second = db.save_message(conv, "assistant", "hi").await.unwrap();
        let _unrequested = db.save_message(conv, "user", "bye").await.unwrap();

        let fetched = db.messages_by_ids(&[first, second]).await.unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].0, first);
        assert_eq!(fetched[0].1.content, "hello");
        assert_eq!(fetched[1].0, second);
        assert_eq!(fetched[1].1.content, "hi");
    }

    /// An empty ID slice yields an empty result.
    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let db = test_store().await;
        let fetched = db.messages_by_ids(&[]).await.unwrap();
        assert!(fetched.is_empty());
    }

    /// Unknown IDs are skipped rather than reported as errors.
    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let db = test_store().await;
        let missing = [MessageId(999), MessageId(1000)];
        let fetched = db.messages_by_ids(&missing).await.unwrap();
        assert!(fetched.is_empty());
    }

    /// A stored message can be fetched by its ID.
    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();
        let saved = db.save_message(conv, "user", "hello").await.unwrap();

        let fetched = db.message_by_id(saved).await.unwrap();
        assert!(fetched.is_some());
        let message = fetched.unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(message.content, "hello");
    }

    /// Looking up an unknown ID yields `None`.
    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let db = test_store().await;
        let fetched = db.message_by_id(MessageId(999)).await.unwrap();
        assert!(fetched.is_none());
    }

    /// With no embeddings recorded, every message is reported as unembedded.
    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message(conv, "user", "msg1").await.unwrap();
        db.save_message(conv, "assistant", "msg2").await.unwrap();

        let pending = db.unembedded_message_ids(None).await.unwrap();
        assert_eq!(pending.len(), 2);
        assert_eq!(pending[0].3, "msg1");
        assert_eq!(pending[1].3, "msg2");
    }

    /// A message with an `embeddings_metadata` row is no longer reported.
    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let db = test_store().await;
        let pool = db.pool();
        let conv = db.create_conversation().await.unwrap();

        let embedded = db.save_message(conv, "user", "msg1").await.unwrap();
        let not_embedded = db.save_message(conv, "assistant", "msg2").await.unwrap();

        let point = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(embedded)
        .bind(&point)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let pending = db.unembedded_message_ids(None).await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].0, not_embedded);
        assert_eq!(pending[0].3, "msg2");
    }

    /// The optional limit caps the number of unembedded rows returned.
    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        for n in 0..10 {
            let text = format!("msg{n}");
            db.save_message(conv, "user", &text).await.unwrap();
        }

        let pending = db.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(pending.len(), 3);
    }

    /// The count goes from zero to the number of saved messages.
    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let before = db.count_messages(conv).await.unwrap();
        assert_eq!(before, 0);

        db.save_message(conv, "user", "msg1").await.unwrap();
        db.save_message(conv, "assistant", "msg2").await.unwrap();

        let after = db.count_messages(conv).await.unwrap();
        assert_eq!(after, 2);
    }

    /// Only messages with an ID strictly above the cursor are counted.
    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let first = db.save_message(conv, "user", "msg1").await.unwrap();
        let _middle = db.save_message(conv, "assistant", "msg2").await.unwrap();
        let last = db.save_message(conv, "user", "msg3").await.unwrap();

        let after_start = db.count_messages_after(conv, MessageId(0)).await.unwrap();
        assert_eq!(after_start, 3);

        let after_first = db.count_messages_after(conv, first).await.unwrap();
        assert_eq!(after_first, 2);

        let after_last = db.count_messages_after(conv, last).await.unwrap();
        assert_eq!(after_last, 0);
    }

    /// The range starts right after the cursor and keeps ascending order.
    #[tokio::test]
    async fn load_messages_range_basic() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let cursor = db.save_message(conv, "user", "msg1").await.unwrap();
        let second = db.save_message(conv, "assistant", "msg2").await.unwrap();
        let third = db.save_message(conv, "user", "msg3").await.unwrap();

        let range = db.load_messages_range(conv, cursor, 10).await.unwrap();
        assert_eq!(range.len(), 2);
        assert_eq!(range[0].0, second);
        assert_eq!(range[0].2, "msg2");
        assert_eq!(range[1].0, third);
        assert_eq!(range[1].2, "msg3");
    }

    /// The limit caps the number of rows in the range.
    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message(conv, "user", "msg1").await.unwrap();
        db.save_message(conv, "assistant", "msg2").await.unwrap();
        db.save_message(conv, "user", "msg3").await.unwrap();

        let range = db
            .load_messages_range(conv, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(range.len(), 2);
    }

    /// A keyword matches every message containing it, with positive scores.
    #[tokio::test]
    async fn keyword_search_basic() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message(conv, "user", "rust programming language")
            .await
            .unwrap();
        db.save_message(conv, "assistant", "python is great too")
            .await
            .unwrap();
        db.save_message(conv, "user", "I love rust and cargo")
            .await
            .unwrap();

        let hits = db.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(hits.len(), 2);
        for (_, score) in &hits {
            assert!(*score > 0.0);
        }
    }

    /// A conversation filter restricts hits to that conversation.
    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let db = test_store().await;
        let conv_a = db.create_conversation().await.unwrap();
        let conv_b = db.create_conversation().await.unwrap();

        db.save_message(conv_a, "user", "hello world")
            .await
            .unwrap();
        db.save_message(conv_b, "user", "hello universe")
            .await
            .unwrap();

        let hits = db
            .keyword_search("hello", 10, Some(conv_a))
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
    }

    /// A term that appears nowhere produces no hits.
    #[tokio::test]
    async fn keyword_search_no_match() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message(conv, "user", "hello world")
            .await
            .unwrap();

        let hits = db.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(hits.is_empty());
    }

    /// The limit caps the number of hits.
    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        for n in 0..10 {
            let text = format!("test message {n}");
            db.save_message(conv, "user", &text).await.unwrap();
        }

        let hits = db.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    /// Visibility flags passed on save are the ones read back.
    #[tokio::test]
    async fn save_message_with_metadata_stores_visibility() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let saved = db
            .save_message_with_metadata(conv, "user", "hello", "[]", false, true)
            .await
            .unwrap();

        let loaded = db.load_history(conv, 10).await.unwrap();
        assert_eq!(loaded.len(), 1);

        let metadata = &loaded[0].metadata;
        assert!(!metadata.agent_visible);
        assert!(metadata.user_visible);
        assert_eq!(saved, MessageId(1));
    }

    /// Filtering on `agent_visible = true` drops agent-hidden messages.
    #[tokio::test]
    async fn load_history_filtered_by_agent_visible() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message_with_metadata(conv, "user", "visible to agent", "[]", true, true)
            .await
            .unwrap();
        db.save_message_with_metadata(conv, "user", "user only", "[]", false, true)
            .await
            .unwrap();

        let for_agent = db
            .load_history_filtered(conv, 50, Some(true), None)
            .await
            .unwrap();
        assert_eq!(for_agent.len(), 1);
        assert_eq!(for_agent[0].content, "visible to agent");
    }

    /// Filtering on `user_visible = true` drops user-hidden messages.
    #[tokio::test]
    async fn load_history_filtered_by_user_visible() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message_with_metadata(conv, "system", "agent only summary", "[]", true, false)
            .await
            .unwrap();
        db.save_message_with_metadata(conv, "user", "user sees this", "[]", true, true)
            .await
            .unwrap();

        let for_user = db
            .load_history_filtered(conv, 50, None, Some(true))
            .await
            .unwrap();
        assert_eq!(for_user.len(), 1);
        assert_eq!(for_user[0].content, "user sees this");
    }

    /// With both filters unset, every message is returned.
    #[tokio::test]
    async fn load_history_filtered_no_filter_returns_all() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message_with_metadata(conv, "user", "msg1", "[]", true, false)
            .await
            .unwrap();
        db.save_message_with_metadata(conv, "user", "msg2", "[]", false, true)
            .await
            .unwrap();

        let unfiltered = db
            .load_history_filtered(conv, 50, None, None)
            .await
            .unwrap();
        assert_eq!(unfiltered.len(), 2);
    }

    /// Compaction hides the covered messages from the agent, leaves later
    /// ones untouched, and appends an agent-only summary.
    #[tokio::test]
    async fn replace_conversation_marks_originals_and_inserts_summary() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let first = db.save_message(conv, "user", "first").await.unwrap();
        let second = db
            .save_message(conv, "assistant", "second")
            .await
            .unwrap();
        let third = db.save_message(conv, "user", "third").await.unwrap();

        let summary_id = db
            .replace_conversation(conv, first..=second, "system", "summary text")
            .await
            .unwrap();

        let everything = db.load_history(conv, 50).await.unwrap();

        // Messages inside the compacted range stay visible to the user only.
        let compacted_first = everything.iter().find(|m| m.content == "first").unwrap();
        assert!(!compacted_first.metadata.agent_visible);
        assert!(compacted_first.metadata.user_visible);

        let compacted_second = everything.iter().find(|m| m.content == "second").unwrap();
        assert!(!compacted_second.metadata.agent_visible);

        // The message outside the range is untouched.
        let untouched = everything.iter().find(|m| m.content == "third").unwrap();
        assert!(untouched.metadata.agent_visible);

        // The summary is for the agent only and gets the newest ID.
        let summary = everything
            .iter()
            .find(|m| m.content == "summary text")
            .unwrap();
        assert!(summary.metadata.agent_visible);
        assert!(!summary.metadata.user_visible);
        assert!(summary_id > third);
    }

    /// The oldest IDs are returned in ascending order, capped by `n`.
    #[tokio::test]
    async fn oldest_message_ids_returns_in_order() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let first = db.save_message(conv, "user", "a").await.unwrap();
        let second = db.save_message(conv, "assistant", "b").await.unwrap();
        let third = db.save_message(conv, "user", "c").await.unwrap();

        let oldest_two = db.oldest_message_ids(conv, 2).await.unwrap();
        assert_eq!(oldest_two, vec![first, second]);
        assert!(oldest_two[0] < oldest_two[1]);

        let everything = db.oldest_message_ids(conv, 10).await.unwrap();
        assert_eq!(everything, vec![first, second, third]);
    }

    /// A plain `save_message` is visible to both sides and not compacted.
    #[tokio::test]
    async fn message_metadata_default_both_visible() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message(conv, "user", "normal").await.unwrap();

        let loaded = db.load_history(conv, 10).await.unwrap();
        let metadata = &loaded[0].metadata;
        assert!(metadata.agent_visible);
        assert!(metadata.user_visible);
        assert!(metadata.compacted_at.is_none());
    }

    /// `load_history` turns a stored `"[]"` into an empty parts list.
    #[tokio::test]
    async fn load_history_empty_parts_json_fast_path() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message_with_parts(conv, "user", "hello", "[]")
            .await
            .unwrap();

        let loaded = db.load_history(conv, 10).await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(
            loaded[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec"
        );
    }

    /// Non-empty parts JSON round-trips through save and load.
    #[tokio::test]
    async fn load_history_non_empty_parts_json_parsed() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let serialized = serde_json::to_string(&vec![MessagePart::ToolResult {
            tool_use_id: "t1".into(),
            content: "result".into(),
            is_error: false,
        }])
        .unwrap();

        db.save_message_with_parts(conv, "user", "hello", &serialized)
            .await
            .unwrap();

        let loaded = db.load_history(conv, 10).await.unwrap();
        assert_eq!(loaded.len(), 1);

        let parts = &loaded[0].parts;
        assert_eq!(parts.len(), 1);
        assert!(
            matches!(&parts[0], MessagePart::ToolResult { content, .. } if content == "result")
        );
    }

    /// `message_by_id` turns a stored `"[]"` into an empty parts list.
    #[tokio::test]
    async fn message_by_id_empty_parts_json_fast_path() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let saved = db
            .save_message_with_parts(conv, "user", "msg", "[]")
            .await
            .unwrap();

        let message = db.message_by_id(saved).await.unwrap().unwrap();
        assert!(
            message.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
        );
    }

    /// `messages_by_ids` turns a stored `"[]"` into an empty parts list.
    #[tokio::test]
    async fn messages_by_ids_empty_parts_json_fast_path() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        let saved = db
            .save_message_with_parts(conv, "user", "msg", "[]")
            .await
            .unwrap();

        let fetched = db.messages_by_ids(&[saved]).await.unwrap();
        assert_eq!(fetched.len(), 1);
        assert!(
            fetched[0].1.parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
        );
    }

    /// `load_history_filtered` turns a stored `"[]"` into an empty parts list.
    #[tokio::test]
    async fn load_history_filtered_empty_parts_json_fast_path() {
        let db = test_store().await;
        let conv = db.create_conversation().await.unwrap();

        db.save_message_with_metadata(conv, "user", "msg", "[]", true, true)
            .await
            .unwrap();

        let filtered = db
            .load_history_filtered(conv, 10, Some(true), None)
            .await
            .unwrap();
        assert_eq!(filtered.len(), 1);
        assert!(
            filtered[0].parts.is_empty(),
            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
        );
    }
}