Skip to main content

zeph_memory/sqlite/
messages.rs

1use zeph_llm::provider::{Message, MessagePart, Role};
2
3use super::SqliteStore;
4use crate::error::MemoryError;
5use crate::types::{ConversationId, MessageId};
6
7fn parse_role(s: &str) -> Role {
8    match s {
9        "assistant" => Role::Assistant,
10        "system" => Role::System,
11        _ => Role::User,
12    }
13}
14
15#[must_use]
16pub fn role_str(role: Role) -> &'static str {
17    match role {
18        Role::System => "system",
19        Role::User => "user",
20        Role::Assistant => "assistant",
21    }
22}
23
24impl SqliteStore {
25    /// Create a new conversation and return its ID.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the insert fails.
30    pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
31        let row: (ConversationId,) =
32            sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
33                .fetch_one(&self.pool)
34                .await?;
35        Ok(row.0)
36    }
37
38    /// Save a message to the given conversation and return the message ID.
39    ///
40    /// # Errors
41    ///
42    /// Returns an error if the insert fails.
43    pub async fn save_message(
44        &self,
45        conversation_id: ConversationId,
46        role: &str,
47        content: &str,
48    ) -> Result<MessageId, MemoryError> {
49        self.save_message_with_parts(conversation_id, role, content, "[]")
50            .await
51    }
52
53    /// Save a message with structured parts JSON.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the insert fails.
58    pub async fn save_message_with_parts(
59        &self,
60        conversation_id: ConversationId,
61        role: &str,
62        content: &str,
63        parts_json: &str,
64    ) -> Result<MessageId, MemoryError> {
65        let row: (MessageId,) = sqlx::query_as(
66            "INSERT INTO messages (conversation_id, role, content, parts) VALUES (?, ?, ?, ?) RETURNING id",
67        )
68        .bind(conversation_id)
69        .bind(role)
70        .bind(content)
71        .bind(parts_json)
72        .fetch_one(&self.pool)
73        .await
74        ?;
75        Ok(row.0)
76    }
77
78    /// Load the most recent messages for a conversation, up to `limit`.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the query fails.
83    pub async fn load_history(
84        &self,
85        conversation_id: ConversationId,
86        limit: u32,
87    ) -> Result<Vec<Message>, MemoryError> {
88        let rows: Vec<(String, String, String)> = sqlx::query_as(
89            "SELECT role, content, parts FROM (\
90                SELECT role, content, parts, id FROM messages \
91                WHERE conversation_id = ? \
92                ORDER BY id DESC \
93                LIMIT ?\
94             ) ORDER BY id ASC",
95        )
96        .bind(conversation_id)
97        .bind(limit)
98        .fetch_all(&self.pool)
99        .await?;
100
101        let messages = rows
102            .into_iter()
103            .map(|(role_str, content, parts_json)| {
104                let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
105                Message {
106                    role: parse_role(&role_str),
107                    content,
108                    parts,
109                }
110            })
111            .collect();
112        Ok(messages)
113    }
114
115    /// Return the ID of the most recent conversation, if any.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if the query fails.
120    pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
121        let row: Option<(ConversationId,)> =
122            sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
123                .fetch_optional(&self.pool)
124                .await?;
125        Ok(row.map(|r| r.0))
126    }
127
128    /// Fetch a single message by its ID.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the query fails.
133    pub async fn message_by_id(
134        &self,
135        message_id: MessageId,
136    ) -> Result<Option<Message>, MemoryError> {
137        let row: Option<(String, String, String)> =
138            sqlx::query_as("SELECT role, content, parts FROM messages WHERE id = ?")
139                .bind(message_id)
140                .fetch_optional(&self.pool)
141                .await?;
142
143        Ok(row.map(|(role_str, content, parts_json)| {
144            let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
145            Message {
146                role: parse_role(&role_str),
147                content,
148                parts,
149            }
150        }))
151    }
152
153    /// Fetch messages by a list of IDs in a single query.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if the query fails.
158    pub async fn messages_by_ids(
159        &self,
160        ids: &[MessageId],
161    ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
162        if ids.is_empty() {
163            return Ok(Vec::new());
164        }
165
166        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
167
168        let query =
169            format!("SELECT id, role, content, parts FROM messages WHERE id IN ({placeholders})");
170        let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
171        for &id in ids {
172            q = q.bind(id);
173        }
174
175        let rows = q.fetch_all(&self.pool).await?;
176
177        Ok(rows
178            .into_iter()
179            .map(|(id, role_str, content, parts_json)| {
180                let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
181                (
182                    id,
183                    Message {
184                        role: parse_role(&role_str),
185                        content,
186                        parts,
187                    },
188                )
189            })
190            .collect())
191    }
192
193    /// Return message IDs and content for messages without embeddings.
194    ///
195    /// # Errors
196    ///
197    /// Returns an error if the query fails.
198    pub async fn unembedded_message_ids(
199        &self,
200        limit: Option<usize>,
201    ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
202        let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
203
204        let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
205            "SELECT m.id, m.conversation_id, m.role, m.content \
206             FROM messages m \
207             LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
208             WHERE em.id IS NULL \
209             ORDER BY m.id ASC \
210             LIMIT ?",
211        )
212        .bind(effective_limit)
213        .fetch_all(&self.pool)
214        .await?;
215
216        Ok(rows)
217    }
218
219    /// Count the number of messages in a conversation.
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if the query fails.
224    pub async fn count_messages(
225        &self,
226        conversation_id: ConversationId,
227    ) -> Result<i64, MemoryError> {
228        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
229            .bind(conversation_id)
230            .fetch_one(&self.pool)
231            .await?;
232        Ok(row.0)
233    }
234
235    /// Count messages in a conversation with id greater than `after_id`.
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if the query fails.
240    pub async fn count_messages_after(
241        &self,
242        conversation_id: ConversationId,
243        after_id: MessageId,
244    ) -> Result<i64, MemoryError> {
245        let row: (i64,) =
246            sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
247                .bind(conversation_id)
248                .bind(after_id)
249                .fetch_one(&self.pool)
250                .await?;
251        Ok(row.0)
252    }
253
254    /// Full-text keyword search over messages using FTS5.
255    ///
256    /// Returns message IDs with BM25 relevance scores (lower = more relevant,
257    /// negated to positive for consistency with vector scores).
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if the query fails.
262    pub async fn keyword_search(
263        &self,
264        query: &str,
265        limit: usize,
266        conversation_id: Option<ConversationId>,
267    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
268        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
269
270        let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
271            sqlx::query_as(
272                "SELECT m.id, -rank AS score \
273                 FROM messages_fts f \
274                 JOIN messages m ON m.id = f.rowid \
275                 WHERE messages_fts MATCH ? AND m.conversation_id = ? \
276                 ORDER BY rank \
277                 LIMIT ?",
278            )
279            .bind(query)
280            .bind(cid)
281            .bind(effective_limit)
282            .fetch_all(&self.pool)
283            .await?
284        } else {
285            sqlx::query_as(
286                "SELECT f.rowid, -rank AS score \
287                 FROM messages_fts f \
288                 WHERE messages_fts MATCH ? \
289                 ORDER BY rank \
290                 LIMIT ?",
291            )
292            .bind(query)
293            .bind(effective_limit)
294            .fetch_all(&self.pool)
295            .await?
296        };
297
298        Ok(rows)
299    }
300
301    /// Load a range of messages after a given message ID.
302    ///
303    /// # Errors
304    ///
305    /// Returns an error if the query fails.
306    pub async fn load_messages_range(
307        &self,
308        conversation_id: ConversationId,
309        after_message_id: MessageId,
310        limit: usize,
311    ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
312        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
313
314        let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
315            "SELECT id, role, content FROM messages \
316             WHERE conversation_id = ? AND id > ? \
317             ORDER BY id ASC LIMIT ?",
318        )
319        .bind(conversation_id)
320        .bind(after_message_id)
321        .bind(effective_limit)
322        .fetch_all(&self.pool)
323        .await?;
324
325        Ok(rows)
326    }
327}
328
329#[cfg(test)]
330mod tests {
331    use super::*;
332
333    async fn test_store() -> SqliteStore {
334        SqliteStore::new(":memory:").await.unwrap()
335    }
336
337    #[tokio::test]
338    async fn create_conversation_returns_id() {
339        let store = test_store().await;
340        let id1 = store.create_conversation().await.unwrap();
341        let id2 = store.create_conversation().await.unwrap();
342        assert_eq!(id1, ConversationId(1));
343        assert_eq!(id2, ConversationId(2));
344    }
345
346    #[tokio::test]
347    async fn save_and_load_messages() {
348        let store = test_store().await;
349        let cid = store.create_conversation().await.unwrap();
350
351        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
352        let msg_id2 = store
353            .save_message(cid, "assistant", "hi there")
354            .await
355            .unwrap();
356
357        assert_eq!(msg_id1, MessageId(1));
358        assert_eq!(msg_id2, MessageId(2));
359
360        let history = store.load_history(cid, 50).await.unwrap();
361        assert_eq!(history.len(), 2);
362        assert_eq!(history[0].role, Role::User);
363        assert_eq!(history[0].content, "hello");
364        assert_eq!(history[1].role, Role::Assistant);
365        assert_eq!(history[1].content, "hi there");
366    }
367
368    #[tokio::test]
369    async fn load_history_respects_limit() {
370        let store = test_store().await;
371        let cid = store.create_conversation().await.unwrap();
372
373        for i in 0..10 {
374            store
375                .save_message(cid, "user", &format!("msg {i}"))
376                .await
377                .unwrap();
378        }
379
380        let history = store.load_history(cid, 3).await.unwrap();
381        assert_eq!(history.len(), 3);
382        assert_eq!(history[0].content, "msg 7");
383        assert_eq!(history[1].content, "msg 8");
384        assert_eq!(history[2].content, "msg 9");
385    }
386
387    #[tokio::test]
388    async fn latest_conversation_id_empty() {
389        let store = test_store().await;
390        assert!(store.latest_conversation_id().await.unwrap().is_none());
391    }
392
393    #[tokio::test]
394    async fn latest_conversation_id_returns_newest() {
395        let store = test_store().await;
396        store.create_conversation().await.unwrap();
397        let id2 = store.create_conversation().await.unwrap();
398        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
399    }
400
401    #[tokio::test]
402    async fn messages_isolated_per_conversation() {
403        let store = test_store().await;
404        let cid1 = store.create_conversation().await.unwrap();
405        let cid2 = store.create_conversation().await.unwrap();
406
407        store.save_message(cid1, "user", "conv1").await.unwrap();
408        store.save_message(cid2, "user", "conv2").await.unwrap();
409
410        let h1 = store.load_history(cid1, 50).await.unwrap();
411        let h2 = store.load_history(cid2, 50).await.unwrap();
412        assert_eq!(h1.len(), 1);
413        assert_eq!(h1[0].content, "conv1");
414        assert_eq!(h2.len(), 1);
415        assert_eq!(h2[0].content, "conv2");
416    }
417
418    #[tokio::test]
419    async fn pool_accessor_returns_valid_pool() {
420        let store = test_store().await;
421        let pool = store.pool();
422        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
423        assert_eq!(row.0, 1);
424    }
425
426    #[tokio::test]
427    async fn embeddings_metadata_table_exists() {
428        let store = test_store().await;
429        let result: (i64,) = sqlx::query_as(
430            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
431        )
432        .fetch_one(store.pool())
433        .await
434        .unwrap();
435        assert_eq!(result.0, 1);
436    }
437
438    #[tokio::test]
439    async fn cascade_delete_removes_embeddings_metadata() {
440        let store = test_store().await;
441        let pool = store.pool();
442
443        let cid = store.create_conversation().await.unwrap();
444        let msg_id = store.save_message(cid, "user", "test").await.unwrap();
445
446        let point_id = uuid::Uuid::new_v4().to_string();
447        sqlx::query(
448            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
449             VALUES (?, ?, ?)",
450        )
451        .bind(msg_id)
452        .bind(&point_id)
453        .bind(768_i64)
454        .execute(pool)
455        .await
456        .unwrap();
457
458        let before: (i64,) =
459            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
460                .bind(msg_id)
461                .fetch_one(pool)
462                .await
463                .unwrap();
464        assert_eq!(before.0, 1);
465
466        sqlx::query("DELETE FROM messages WHERE id = ?")
467            .bind(msg_id)
468            .execute(pool)
469            .await
470            .unwrap();
471
472        let after: (i64,) =
473            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
474                .bind(msg_id)
475                .fetch_one(pool)
476                .await
477                .unwrap();
478        assert_eq!(after.0, 0);
479    }
480
481    #[tokio::test]
482    async fn messages_by_ids_batch_fetch() {
483        let store = test_store().await;
484        let cid = store.create_conversation().await.unwrap();
485        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
486        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
487        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
488
489        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
490        assert_eq!(results.len(), 2);
491        assert_eq!(results[0].0, id1);
492        assert_eq!(results[0].1.content, "hello");
493        assert_eq!(results[1].0, id2);
494        assert_eq!(results[1].1.content, "hi");
495    }
496
497    #[tokio::test]
498    async fn messages_by_ids_empty_input() {
499        let store = test_store().await;
500        let results = store.messages_by_ids(&[]).await.unwrap();
501        assert!(results.is_empty());
502    }
503
504    #[tokio::test]
505    async fn messages_by_ids_nonexistent() {
506        let store = test_store().await;
507        let results = store
508            .messages_by_ids(&[MessageId(999), MessageId(1000)])
509            .await
510            .unwrap();
511        assert!(results.is_empty());
512    }
513
514    #[tokio::test]
515    async fn message_by_id_fetches_existing() {
516        let store = test_store().await;
517        let cid = store.create_conversation().await.unwrap();
518        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
519
520        let msg = store.message_by_id(msg_id).await.unwrap();
521        assert!(msg.is_some());
522        let msg = msg.unwrap();
523        assert_eq!(msg.role, Role::User);
524        assert_eq!(msg.content, "hello");
525    }
526
527    #[tokio::test]
528    async fn message_by_id_returns_none_for_nonexistent() {
529        let store = test_store().await;
530        let msg = store.message_by_id(MessageId(999)).await.unwrap();
531        assert!(msg.is_none());
532    }
533
534    #[tokio::test]
535    async fn unembedded_message_ids_returns_all_when_none_embedded() {
536        let store = test_store().await;
537        let cid = store.create_conversation().await.unwrap();
538
539        store.save_message(cid, "user", "msg1").await.unwrap();
540        store.save_message(cid, "assistant", "msg2").await.unwrap();
541
542        let unembedded = store.unembedded_message_ids(None).await.unwrap();
543        assert_eq!(unembedded.len(), 2);
544        assert_eq!(unembedded[0].3, "msg1");
545        assert_eq!(unembedded[1].3, "msg2");
546    }
547
548    #[tokio::test]
549    async fn unembedded_message_ids_excludes_embedded() {
550        let store = test_store().await;
551        let pool = store.pool();
552        let cid = store.create_conversation().await.unwrap();
553
554        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
555        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
556
557        let point_id = uuid::Uuid::new_v4().to_string();
558        sqlx::query(
559            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
560             VALUES (?, ?, ?)",
561        )
562        .bind(msg_id1)
563        .bind(&point_id)
564        .bind(768_i64)
565        .execute(pool)
566        .await
567        .unwrap();
568
569        let unembedded = store.unembedded_message_ids(None).await.unwrap();
570        assert_eq!(unembedded.len(), 1);
571        assert_eq!(unembedded[0].0, msg_id2);
572        assert_eq!(unembedded[0].3, "msg2");
573    }
574
575    #[tokio::test]
576    async fn unembedded_message_ids_respects_limit() {
577        let store = test_store().await;
578        let cid = store.create_conversation().await.unwrap();
579
580        for i in 0..10 {
581            store
582                .save_message(cid, "user", &format!("msg{i}"))
583                .await
584                .unwrap();
585        }
586
587        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
588        assert_eq!(unembedded.len(), 3);
589    }
590
591    #[tokio::test]
592    async fn count_messages_returns_correct_count() {
593        let store = test_store().await;
594        let cid = store.create_conversation().await.unwrap();
595
596        assert_eq!(store.count_messages(cid).await.unwrap(), 0);
597
598        store.save_message(cid, "user", "msg1").await.unwrap();
599        store.save_message(cid, "assistant", "msg2").await.unwrap();
600
601        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
602    }
603
604    #[tokio::test]
605    async fn count_messages_after_filters_correctly() {
606        let store = test_store().await;
607        let cid = store.create_conversation().await.unwrap();
608
609        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
610        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
611        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
612
613        assert_eq!(
614            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
615            3
616        );
617        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
618        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
619    }
620
621    #[tokio::test]
622    async fn load_messages_range_basic() {
623        let store = test_store().await;
624        let cid = store.create_conversation().await.unwrap();
625
626        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
627        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
628        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
629
630        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
631        assert_eq!(msgs.len(), 2);
632        assert_eq!(msgs[0].0, msg_id2);
633        assert_eq!(msgs[0].2, "msg2");
634        assert_eq!(msgs[1].0, msg_id3);
635        assert_eq!(msgs[1].2, "msg3");
636    }
637
638    #[tokio::test]
639    async fn load_messages_range_respects_limit() {
640        let store = test_store().await;
641        let cid = store.create_conversation().await.unwrap();
642
643        store.save_message(cid, "user", "msg1").await.unwrap();
644        store.save_message(cid, "assistant", "msg2").await.unwrap();
645        store.save_message(cid, "user", "msg3").await.unwrap();
646
647        let msgs = store
648            .load_messages_range(cid, MessageId(0), 2)
649            .await
650            .unwrap();
651        assert_eq!(msgs.len(), 2);
652    }
653
654    #[tokio::test]
655    async fn keyword_search_basic() {
656        let store = test_store().await;
657        let cid = store.create_conversation().await.unwrap();
658
659        store
660            .save_message(cid, "user", "rust programming language")
661            .await
662            .unwrap();
663        store
664            .save_message(cid, "assistant", "python is great too")
665            .await
666            .unwrap();
667        store
668            .save_message(cid, "user", "I love rust and cargo")
669            .await
670            .unwrap();
671
672        let results = store.keyword_search("rust", 10, None).await.unwrap();
673        assert_eq!(results.len(), 2);
674        assert!(results.iter().all(|(_, score)| *score > 0.0));
675    }
676
677    #[tokio::test]
678    async fn keyword_search_with_conversation_filter() {
679        let store = test_store().await;
680        let cid1 = store.create_conversation().await.unwrap();
681        let cid2 = store.create_conversation().await.unwrap();
682
683        store
684            .save_message(cid1, "user", "hello world")
685            .await
686            .unwrap();
687        store
688            .save_message(cid2, "user", "hello universe")
689            .await
690            .unwrap();
691
692        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
693        assert_eq!(results.len(), 1);
694    }
695
696    #[tokio::test]
697    async fn keyword_search_no_match() {
698        let store = test_store().await;
699        let cid = store.create_conversation().await.unwrap();
700
701        store
702            .save_message(cid, "user", "hello world")
703            .await
704            .unwrap();
705
706        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
707        assert!(results.is_empty());
708    }
709
710    #[tokio::test]
711    async fn keyword_search_respects_limit() {
712        let store = test_store().await;
713        let cid = store.create_conversation().await.unwrap();
714
715        for i in 0..10 {
716            store
717                .save_message(cid, "user", &format!("test message {i}"))
718                .await
719                .unwrap();
720        }
721
722        let results = store.keyword_search("test", 3, None).await.unwrap();
723        assert_eq!(results.len(), 3);
724    }
725}