Skip to main content

zeph_memory/sqlite/
messages.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10fn parse_role(s: &str) -> Role {
11    match s {
12        "assistant" => Role::Assistant,
13        "system" => Role::System,
14        _ => Role::User,
15    }
16}
17
18#[must_use]
19pub fn role_str(role: Role) -> &'static str {
20    match role {
21        Role::System => "system",
22        Role::User => "user",
23        Role::Assistant => "assistant",
24    }
25}
26
27impl SqliteStore {
28    /// Create a new conversation and return its ID.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if the insert fails.
33    pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
34        let row: (ConversationId,) =
35            sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
36                .fetch_one(&self.pool)
37                .await?;
38        Ok(row.0)
39    }
40
41    /// Save a message to the given conversation and return the message ID.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the insert fails.
46    pub async fn save_message(
47        &self,
48        conversation_id: ConversationId,
49        role: &str,
50        content: &str,
51    ) -> Result<MessageId, MemoryError> {
52        self.save_message_with_parts(conversation_id, role, content, "[]")
53            .await
54    }
55
56    /// Save a message with structured parts JSON.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the insert fails.
61    pub async fn save_message_with_parts(
62        &self,
63        conversation_id: ConversationId,
64        role: &str,
65        content: &str,
66        parts_json: &str,
67    ) -> Result<MessageId, MemoryError> {
68        self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
69            .await
70    }
71
72    /// Save a message with visibility metadata.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the insert fails.
77    pub async fn save_message_with_metadata(
78        &self,
79        conversation_id: ConversationId,
80        role: &str,
81        content: &str,
82        parts_json: &str,
83        agent_visible: bool,
84        user_visible: bool,
85    ) -> Result<MessageId, MemoryError> {
86        let row: (MessageId,) = sqlx::query_as(
87            "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
88             VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
89        )
90        .bind(conversation_id)
91        .bind(role)
92        .bind(content)
93        .bind(parts_json)
94        .bind(i64::from(agent_visible))
95        .bind(i64::from(user_visible))
96        .fetch_one(&self.pool)
97        .await?;
98        Ok(row.0)
99    }
100
101    /// Load the most recent messages for a conversation, up to `limit`.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the query fails.
106    pub async fn load_history(
107        &self,
108        conversation_id: ConversationId,
109        limit: u32,
110    ) -> Result<Vec<Message>, MemoryError> {
111        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
112            "SELECT role, content, parts, agent_visible, user_visible FROM (\
113                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
114                WHERE conversation_id = ? \
115                ORDER BY id DESC \
116                LIMIT ?\
117             ) ORDER BY id ASC",
118        )
119        .bind(conversation_id)
120        .bind(limit)
121        .fetch_all(&self.pool)
122        .await?;
123
124        let messages = rows
125            .into_iter()
126            .map(
127                |(role_str, content, parts_json, agent_visible, user_visible)| {
128                    let parts: Vec<MessagePart> =
129                        serde_json::from_str(&parts_json).unwrap_or_default();
130                    Message {
131                        role: parse_role(&role_str),
132                        content,
133                        parts,
134                        metadata: MessageMetadata {
135                            agent_visible: agent_visible != 0,
136                            user_visible: user_visible != 0,
137                            compacted_at: None,
138                        },
139                    }
140                },
141            )
142            .collect();
143        Ok(messages)
144    }
145
146    /// Load messages filtered by visibility flags.
147    ///
148    /// Pass `Some(true)` to filter by a flag, `None` to skip filtering.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if the query fails.
153    pub async fn load_history_filtered(
154        &self,
155        conversation_id: ConversationId,
156        limit: u32,
157        agent_visible: Option<bool>,
158        user_visible: Option<bool>,
159    ) -> Result<Vec<Message>, MemoryError> {
160        let av = agent_visible.map(i64::from);
161        let uv = user_visible.map(i64::from);
162
163        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
164            "SELECT role, content, parts, agent_visible, user_visible FROM (\
165                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
166                WHERE conversation_id = ? \
167                  AND (? IS NULL OR agent_visible = ?) \
168                  AND (? IS NULL OR user_visible = ?) \
169                ORDER BY id DESC \
170                LIMIT ?\
171             ) ORDER BY id ASC",
172        )
173        .bind(conversation_id)
174        .bind(av)
175        .bind(av)
176        .bind(uv)
177        .bind(uv)
178        .bind(limit)
179        .fetch_all(&self.pool)
180        .await?;
181
182        let messages = rows
183            .into_iter()
184            .map(
185                |(role_str, content, parts_json, agent_visible, user_visible)| {
186                    let parts: Vec<MessagePart> =
187                        serde_json::from_str(&parts_json).unwrap_or_default();
188                    Message {
189                        role: parse_role(&role_str),
190                        content,
191                        parts,
192                        metadata: MessageMetadata {
193                            agent_visible: agent_visible != 0,
194                            user_visible: user_visible != 0,
195                            compacted_at: None,
196                        },
197                    }
198                },
199            )
200            .collect();
201        Ok(messages)
202    }
203
204    /// Atomically mark a range of messages as user-only and insert a summary as agent-only.
205    ///
206    /// Within a single transaction:
207    /// 1. Updates `agent_visible=0, compacted_at=now` for messages in `compacted_range`.
208    /// 2. Inserts `summary_content` with `agent_visible=1, user_visible=0`.
209    ///
210    /// Returns the `MessageId` of the inserted summary.
211    ///
212    /// # Errors
213    ///
214    /// Returns an error if the transaction fails.
215    pub async fn replace_conversation(
216        &self,
217        conversation_id: ConversationId,
218        compacted_range: std::ops::RangeInclusive<MessageId>,
219        summary_role: &str,
220        summary_content: &str,
221    ) -> Result<MessageId, MemoryError> {
222        let now = {
223            let secs = std::time::SystemTime::now()
224                .duration_since(std::time::UNIX_EPOCH)
225                .unwrap_or_default()
226                .as_secs();
227            format!("{secs}")
228        };
229        let start_id = compacted_range.start().0;
230        let end_id = compacted_range.end().0;
231
232        let mut tx = self.pool.begin().await?;
233
234        sqlx::query(
235            "UPDATE messages SET agent_visible = 0, compacted_at = ? \
236             WHERE conversation_id = ? AND id >= ? AND id <= ?",
237        )
238        .bind(&now)
239        .bind(conversation_id)
240        .bind(start_id)
241        .bind(end_id)
242        .execute(&mut *tx)
243        .await?;
244
245        let row: (MessageId,) = sqlx::query_as(
246            "INSERT INTO messages \
247             (conversation_id, role, content, parts, agent_visible, user_visible) \
248             VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
249        )
250        .bind(conversation_id)
251        .bind(summary_role)
252        .bind(summary_content)
253        .fetch_one(&mut *tx)
254        .await?;
255
256        tx.commit().await?;
257
258        Ok(row.0)
259    }
260
261    /// Return the IDs of the N oldest messages in a conversation (ascending order).
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the query fails.
266    pub async fn oldest_message_ids(
267        &self,
268        conversation_id: ConversationId,
269        n: u32,
270    ) -> Result<Vec<MessageId>, MemoryError> {
271        let rows: Vec<(MessageId,)> = sqlx::query_as(
272            "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
273        )
274        .bind(conversation_id)
275        .bind(n)
276        .fetch_all(&self.pool)
277        .await?;
278        Ok(rows.into_iter().map(|r| r.0).collect())
279    }
280
281    /// Return the ID of the most recent conversation, if any.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if the query fails.
286    pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
287        let row: Option<(ConversationId,)> =
288            sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
289                .fetch_optional(&self.pool)
290                .await?;
291        Ok(row.map(|r| r.0))
292    }
293
294    /// Fetch a single message by its ID.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the query fails.
299    pub async fn message_by_id(
300        &self,
301        message_id: MessageId,
302    ) -> Result<Option<Message>, MemoryError> {
303        let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
304            "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
305        )
306        .bind(message_id)
307        .fetch_optional(&self.pool)
308        .await?;
309
310        Ok(row.map(
311            |(role_str, content, parts_json, agent_visible, user_visible)| {
312                let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
313                Message {
314                    role: parse_role(&role_str),
315                    content,
316                    parts,
317                    metadata: MessageMetadata {
318                        agent_visible: agent_visible != 0,
319                        user_visible: user_visible != 0,
320                        compacted_at: None,
321                    },
322                }
323            },
324        ))
325    }
326
327    /// Fetch messages by a list of IDs in a single query.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if the query fails.
332    pub async fn messages_by_ids(
333        &self,
334        ids: &[MessageId],
335    ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
336        if ids.is_empty() {
337            return Ok(Vec::new());
338        }
339
340        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
341
342        let query = format!(
343            "SELECT id, role, content, parts FROM messages \
344             WHERE id IN ({placeholders}) AND agent_visible = 1"
345        );
346        let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
347        for &id in ids {
348            q = q.bind(id);
349        }
350
351        let rows = q.fetch_all(&self.pool).await?;
352
353        Ok(rows
354            .into_iter()
355            .map(|(id, role_str, content, parts_json)| {
356                let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
357                (
358                    id,
359                    Message {
360                        role: parse_role(&role_str),
361                        content,
362                        parts,
363                        metadata: MessageMetadata::default(),
364                    },
365                )
366            })
367            .collect())
368    }
369
370    /// Return message IDs and content for messages without embeddings.
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if the query fails.
375    pub async fn unembedded_message_ids(
376        &self,
377        limit: Option<usize>,
378    ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
379        let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
380
381        let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
382            "SELECT m.id, m.conversation_id, m.role, m.content \
383             FROM messages m \
384             LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
385             WHERE em.id IS NULL \
386             ORDER BY m.id ASC \
387             LIMIT ?",
388        )
389        .bind(effective_limit)
390        .fetch_all(&self.pool)
391        .await?;
392
393        Ok(rows)
394    }
395
396    /// Count the number of messages in a conversation.
397    ///
398    /// # Errors
399    ///
400    /// Returns an error if the query fails.
401    pub async fn count_messages(
402        &self,
403        conversation_id: ConversationId,
404    ) -> Result<i64, MemoryError> {
405        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
406            .bind(conversation_id)
407            .fetch_one(&self.pool)
408            .await?;
409        Ok(row.0)
410    }
411
412    /// Count messages in a conversation with id greater than `after_id`.
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if the query fails.
417    pub async fn count_messages_after(
418        &self,
419        conversation_id: ConversationId,
420        after_id: MessageId,
421    ) -> Result<i64, MemoryError> {
422        let row: (i64,) =
423            sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
424                .bind(conversation_id)
425                .bind(after_id)
426                .fetch_one(&self.pool)
427                .await?;
428        Ok(row.0)
429    }
430
431    /// Full-text keyword search over messages using FTS5.
432    ///
433    /// Returns message IDs with BM25 relevance scores (lower = more relevant,
434    /// negated to positive for consistency with vector scores).
435    ///
436    /// # Errors
437    ///
438    /// Returns an error if the query fails.
439    pub async fn keyword_search(
440        &self,
441        query: &str,
442        limit: usize,
443        conversation_id: Option<ConversationId>,
444    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
445        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
446
447        let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
448            sqlx::query_as(
449                "SELECT m.id, -rank AS score \
450                 FROM messages_fts f \
451                 JOIN messages m ON m.id = f.rowid \
452                 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
453                 ORDER BY rank \
454                 LIMIT ?",
455            )
456            .bind(query)
457            .bind(cid)
458            .bind(effective_limit)
459            .fetch_all(&self.pool)
460            .await?
461        } else {
462            sqlx::query_as(
463                "SELECT m.id, -rank AS score \
464                 FROM messages_fts f \
465                 JOIN messages m ON m.id = f.rowid \
466                 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
467                 ORDER BY rank \
468                 LIMIT ?",
469            )
470            .bind(query)
471            .bind(effective_limit)
472            .fetch_all(&self.pool)
473            .await?
474        };
475
476        Ok(rows)
477    }
478
479    /// Fetch creation timestamps (Unix epoch seconds) for the given message IDs.
480    ///
481    /// Messages without a `created_at` column fall back to 0.
482    ///
483    /// # Errors
484    ///
485    /// Returns an error if the query fails.
486    pub async fn message_timestamps(
487        &self,
488        ids: &[MessageId],
489    ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
490        if ids.is_empty() {
491            return Ok(std::collections::HashMap::new());
492        }
493
494        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
495        let query = format!(
496            "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
497             FROM messages WHERE id IN ({placeholders})"
498        );
499        let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
500        for &id in ids {
501            q = q.bind(id);
502        }
503
504        let rows = q.fetch_all(&self.pool).await?;
505        Ok(rows.into_iter().collect())
506    }
507
508    /// Load a range of messages after a given message ID.
509    ///
510    /// # Errors
511    ///
512    /// Returns an error if the query fails.
513    pub async fn load_messages_range(
514        &self,
515        conversation_id: ConversationId,
516        after_message_id: MessageId,
517        limit: usize,
518    ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
519        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
520
521        let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
522            "SELECT id, role, content FROM messages \
523             WHERE conversation_id = ? AND id > ? \
524             ORDER BY id ASC LIMIT ?",
525        )
526        .bind(conversation_id)
527        .bind(after_message_id)
528        .bind(effective_limit)
529        .fetch_all(&self.pool)
530        .await?;
531
532        Ok(rows)
533    }
534}
535
536#[cfg(test)]
537mod tests {
538    use super::*;
539
540    async fn test_store() -> SqliteStore {
541        SqliteStore::new(":memory:").await.unwrap()
542    }
543
544    #[tokio::test]
545    async fn create_conversation_returns_id() {
546        let store = test_store().await;
547        let id1 = store.create_conversation().await.unwrap();
548        let id2 = store.create_conversation().await.unwrap();
549        assert_eq!(id1, ConversationId(1));
550        assert_eq!(id2, ConversationId(2));
551    }
552
553    #[tokio::test]
554    async fn save_and_load_messages() {
555        let store = test_store().await;
556        let cid = store.create_conversation().await.unwrap();
557
558        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
559        let msg_id2 = store
560            .save_message(cid, "assistant", "hi there")
561            .await
562            .unwrap();
563
564        assert_eq!(msg_id1, MessageId(1));
565        assert_eq!(msg_id2, MessageId(2));
566
567        let history = store.load_history(cid, 50).await.unwrap();
568        assert_eq!(history.len(), 2);
569        assert_eq!(history[0].role, Role::User);
570        assert_eq!(history[0].content, "hello");
571        assert_eq!(history[1].role, Role::Assistant);
572        assert_eq!(history[1].content, "hi there");
573    }
574
575    #[tokio::test]
576    async fn load_history_respects_limit() {
577        let store = test_store().await;
578        let cid = store.create_conversation().await.unwrap();
579
580        for i in 0..10 {
581            store
582                .save_message(cid, "user", &format!("msg {i}"))
583                .await
584                .unwrap();
585        }
586
587        let history = store.load_history(cid, 3).await.unwrap();
588        assert_eq!(history.len(), 3);
589        assert_eq!(history[0].content, "msg 7");
590        assert_eq!(history[1].content, "msg 8");
591        assert_eq!(history[2].content, "msg 9");
592    }
593
594    #[tokio::test]
595    async fn latest_conversation_id_empty() {
596        let store = test_store().await;
597        assert!(store.latest_conversation_id().await.unwrap().is_none());
598    }
599
600    #[tokio::test]
601    async fn latest_conversation_id_returns_newest() {
602        let store = test_store().await;
603        store.create_conversation().await.unwrap();
604        let id2 = store.create_conversation().await.unwrap();
605        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
606    }
607
608    #[tokio::test]
609    async fn messages_isolated_per_conversation() {
610        let store = test_store().await;
611        let cid1 = store.create_conversation().await.unwrap();
612        let cid2 = store.create_conversation().await.unwrap();
613
614        store.save_message(cid1, "user", "conv1").await.unwrap();
615        store.save_message(cid2, "user", "conv2").await.unwrap();
616
617        let h1 = store.load_history(cid1, 50).await.unwrap();
618        let h2 = store.load_history(cid2, 50).await.unwrap();
619        assert_eq!(h1.len(), 1);
620        assert_eq!(h1[0].content, "conv1");
621        assert_eq!(h2.len(), 1);
622        assert_eq!(h2[0].content, "conv2");
623    }
624
625    #[tokio::test]
626    async fn pool_accessor_returns_valid_pool() {
627        let store = test_store().await;
628        let pool = store.pool();
629        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
630        assert_eq!(row.0, 1);
631    }
632
633    #[tokio::test]
634    async fn embeddings_metadata_table_exists() {
635        let store = test_store().await;
636        let result: (i64,) = sqlx::query_as(
637            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
638        )
639        .fetch_one(store.pool())
640        .await
641        .unwrap();
642        assert_eq!(result.0, 1);
643    }
644
645    #[tokio::test]
646    async fn cascade_delete_removes_embeddings_metadata() {
647        let store = test_store().await;
648        let pool = store.pool();
649
650        let cid = store.create_conversation().await.unwrap();
651        let msg_id = store.save_message(cid, "user", "test").await.unwrap();
652
653        let point_id = uuid::Uuid::new_v4().to_string();
654        sqlx::query(
655            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
656             VALUES (?, ?, ?)",
657        )
658        .bind(msg_id)
659        .bind(&point_id)
660        .bind(768_i64)
661        .execute(pool)
662        .await
663        .unwrap();
664
665        let before: (i64,) =
666            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
667                .bind(msg_id)
668                .fetch_one(pool)
669                .await
670                .unwrap();
671        assert_eq!(before.0, 1);
672
673        sqlx::query("DELETE FROM messages WHERE id = ?")
674            .bind(msg_id)
675            .execute(pool)
676            .await
677            .unwrap();
678
679        let after: (i64,) =
680            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
681                .bind(msg_id)
682                .fetch_one(pool)
683                .await
684                .unwrap();
685        assert_eq!(after.0, 0);
686    }
687
688    #[tokio::test]
689    async fn messages_by_ids_batch_fetch() {
690        let store = test_store().await;
691        let cid = store.create_conversation().await.unwrap();
692        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
693        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
694        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
695
696        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
697        assert_eq!(results.len(), 2);
698        assert_eq!(results[0].0, id1);
699        assert_eq!(results[0].1.content, "hello");
700        assert_eq!(results[1].0, id2);
701        assert_eq!(results[1].1.content, "hi");
702    }
703
704    #[tokio::test]
705    async fn messages_by_ids_empty_input() {
706        let store = test_store().await;
707        let results = store.messages_by_ids(&[]).await.unwrap();
708        assert!(results.is_empty());
709    }
710
711    #[tokio::test]
712    async fn messages_by_ids_nonexistent() {
713        let store = test_store().await;
714        let results = store
715            .messages_by_ids(&[MessageId(999), MessageId(1000)])
716            .await
717            .unwrap();
718        assert!(results.is_empty());
719    }
720
721    #[tokio::test]
722    async fn message_by_id_fetches_existing() {
723        let store = test_store().await;
724        let cid = store.create_conversation().await.unwrap();
725        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
726
727        let msg = store.message_by_id(msg_id).await.unwrap();
728        assert!(msg.is_some());
729        let msg = msg.unwrap();
730        assert_eq!(msg.role, Role::User);
731        assert_eq!(msg.content, "hello");
732    }
733
734    #[tokio::test]
735    async fn message_by_id_returns_none_for_nonexistent() {
736        let store = test_store().await;
737        let msg = store.message_by_id(MessageId(999)).await.unwrap();
738        assert!(msg.is_none());
739    }
740
741    #[tokio::test]
742    async fn unembedded_message_ids_returns_all_when_none_embedded() {
743        let store = test_store().await;
744        let cid = store.create_conversation().await.unwrap();
745
746        store.save_message(cid, "user", "msg1").await.unwrap();
747        store.save_message(cid, "assistant", "msg2").await.unwrap();
748
749        let unembedded = store.unembedded_message_ids(None).await.unwrap();
750        assert_eq!(unembedded.len(), 2);
751        assert_eq!(unembedded[0].3, "msg1");
752        assert_eq!(unembedded[1].3, "msg2");
753    }
754
755    #[tokio::test]
756    async fn unembedded_message_ids_excludes_embedded() {
757        let store = test_store().await;
758        let pool = store.pool();
759        let cid = store.create_conversation().await.unwrap();
760
761        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
762        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
763
764        let point_id = uuid::Uuid::new_v4().to_string();
765        sqlx::query(
766            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
767             VALUES (?, ?, ?)",
768        )
769        .bind(msg_id1)
770        .bind(&point_id)
771        .bind(768_i64)
772        .execute(pool)
773        .await
774        .unwrap();
775
776        let unembedded = store.unembedded_message_ids(None).await.unwrap();
777        assert_eq!(unembedded.len(), 1);
778        assert_eq!(unembedded[0].0, msg_id2);
779        assert_eq!(unembedded[0].3, "msg2");
780    }
781
782    #[tokio::test]
783    async fn unembedded_message_ids_respects_limit() {
784        let store = test_store().await;
785        let cid = store.create_conversation().await.unwrap();
786
787        for i in 0..10 {
788            store
789                .save_message(cid, "user", &format!("msg{i}"))
790                .await
791                .unwrap();
792        }
793
794        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
795        assert_eq!(unembedded.len(), 3);
796    }
797
798    #[tokio::test]
799    async fn count_messages_returns_correct_count() {
800        let store = test_store().await;
801        let cid = store.create_conversation().await.unwrap();
802
803        assert_eq!(store.count_messages(cid).await.unwrap(), 0);
804
805        store.save_message(cid, "user", "msg1").await.unwrap();
806        store.save_message(cid, "assistant", "msg2").await.unwrap();
807
808        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
809    }
810
811    #[tokio::test]
812    async fn count_messages_after_filters_correctly() {
813        let store = test_store().await;
814        let cid = store.create_conversation().await.unwrap();
815
816        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
817        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
818        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
819
820        assert_eq!(
821            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
822            3
823        );
824        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
825        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
826    }
827
828    #[tokio::test]
829    async fn load_messages_range_basic() {
830        let store = test_store().await;
831        let cid = store.create_conversation().await.unwrap();
832
833        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
834        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
835        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
836
837        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
838        assert_eq!(msgs.len(), 2);
839        assert_eq!(msgs[0].0, msg_id2);
840        assert_eq!(msgs[0].2, "msg2");
841        assert_eq!(msgs[1].0, msg_id3);
842        assert_eq!(msgs[1].2, "msg3");
843    }
844
845    #[tokio::test]
846    async fn load_messages_range_respects_limit() {
847        let store = test_store().await;
848        let cid = store.create_conversation().await.unwrap();
849
850        store.save_message(cid, "user", "msg1").await.unwrap();
851        store.save_message(cid, "assistant", "msg2").await.unwrap();
852        store.save_message(cid, "user", "msg3").await.unwrap();
853
854        let msgs = store
855            .load_messages_range(cid, MessageId(0), 2)
856            .await
857            .unwrap();
858        assert_eq!(msgs.len(), 2);
859    }
860
861    #[tokio::test]
862    async fn keyword_search_basic() {
863        let store = test_store().await;
864        let cid = store.create_conversation().await.unwrap();
865
866        store
867            .save_message(cid, "user", "rust programming language")
868            .await
869            .unwrap();
870        store
871            .save_message(cid, "assistant", "python is great too")
872            .await
873            .unwrap();
874        store
875            .save_message(cid, "user", "I love rust and cargo")
876            .await
877            .unwrap();
878
879        let results = store.keyword_search("rust", 10, None).await.unwrap();
880        assert_eq!(results.len(), 2);
881        assert!(results.iter().all(|(_, score)| *score > 0.0));
882    }
883
884    #[tokio::test]
885    async fn keyword_search_with_conversation_filter() {
886        let store = test_store().await;
887        let cid1 = store.create_conversation().await.unwrap();
888        let cid2 = store.create_conversation().await.unwrap();
889
890        store
891            .save_message(cid1, "user", "hello world")
892            .await
893            .unwrap();
894        store
895            .save_message(cid2, "user", "hello universe")
896            .await
897            .unwrap();
898
899        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
900        assert_eq!(results.len(), 1);
901    }
902
903    #[tokio::test]
904    async fn keyword_search_no_match() {
905        let store = test_store().await;
906        let cid = store.create_conversation().await.unwrap();
907
908        store
909            .save_message(cid, "user", "hello world")
910            .await
911            .unwrap();
912
913        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
914        assert!(results.is_empty());
915    }
916
917    #[tokio::test]
918    async fn keyword_search_respects_limit() {
919        let store = test_store().await;
920        let cid = store.create_conversation().await.unwrap();
921
922        for i in 0..10 {
923            store
924                .save_message(cid, "user", &format!("test message {i}"))
925                .await
926                .unwrap();
927        }
928
929        let results = store.keyword_search("test", 3, None).await.unwrap();
930        assert_eq!(results.len(), 3);
931    }
932
933    #[tokio::test]
934    async fn save_message_with_metadata_stores_visibility() {
935        let store = test_store().await;
936        let cid = store.create_conversation().await.unwrap();
937
938        let id = store
939            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
940            .await
941            .unwrap();
942
943        let history = store.load_history(cid, 10).await.unwrap();
944        assert_eq!(history.len(), 1);
945        assert!(!history[0].metadata.agent_visible);
946        assert!(history[0].metadata.user_visible);
947        assert_eq!(id, MessageId(1));
948    }
949
950    #[tokio::test]
951    async fn load_history_filtered_by_agent_visible() {
952        let store = test_store().await;
953        let cid = store.create_conversation().await.unwrap();
954
955        store
956            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
957            .await
958            .unwrap();
959        store
960            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
961            .await
962            .unwrap();
963
964        let agent_msgs = store
965            .load_history_filtered(cid, 50, Some(true), None)
966            .await
967            .unwrap();
968        assert_eq!(agent_msgs.len(), 1);
969        assert_eq!(agent_msgs[0].content, "visible to agent");
970    }
971
972    #[tokio::test]
973    async fn load_history_filtered_by_user_visible() {
974        let store = test_store().await;
975        let cid = store.create_conversation().await.unwrap();
976
977        store
978            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
979            .await
980            .unwrap();
981        store
982            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
983            .await
984            .unwrap();
985
986        let user_msgs = store
987            .load_history_filtered(cid, 50, None, Some(true))
988            .await
989            .unwrap();
990        assert_eq!(user_msgs.len(), 1);
991        assert_eq!(user_msgs[0].content, "user sees this");
992    }
993
994    #[tokio::test]
995    async fn load_history_filtered_no_filter_returns_all() {
996        let store = test_store().await;
997        let cid = store.create_conversation().await.unwrap();
998
999        store
1000            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
1001            .await
1002            .unwrap();
1003        store
1004            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
1005            .await
1006            .unwrap();
1007
1008        let all_msgs = store
1009            .load_history_filtered(cid, 50, None, None)
1010            .await
1011            .unwrap();
1012        assert_eq!(all_msgs.len(), 2);
1013    }
1014
1015    #[tokio::test]
1016    async fn replace_conversation_marks_originals_and_inserts_summary() {
1017        let store = test_store().await;
1018        let cid = store.create_conversation().await.unwrap();
1019
1020        let id1 = store.save_message(cid, "user", "first").await.unwrap();
1021        let id2 = store
1022            .save_message(cid, "assistant", "second")
1023            .await
1024            .unwrap();
1025        let id3 = store.save_message(cid, "user", "third").await.unwrap();
1026
1027        let summary_id = store
1028            .replace_conversation(cid, id1..=id2, "system", "summary text")
1029            .await
1030            .unwrap();
1031
1032        // Original messages should be user_only
1033        let all = store.load_history(cid, 50).await.unwrap();
1034        // id1 and id2 marked agent_visible=false, id3 untouched, summary inserted
1035        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
1036        assert!(!by_id1.metadata.agent_visible);
1037        assert!(by_id1.metadata.user_visible);
1038
1039        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
1040        assert!(!by_id2.metadata.agent_visible);
1041
1042        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
1043        assert!(by_id3.metadata.agent_visible);
1044
1045        // Summary is agent_only (agent_visible=1, user_visible=0)
1046        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
1047        assert!(summary.metadata.agent_visible);
1048        assert!(!summary.metadata.user_visible);
1049        assert!(summary_id > id3);
1050    }
1051
1052    #[tokio::test]
1053    async fn oldest_message_ids_returns_in_order() {
1054        let store = test_store().await;
1055        let cid = store.create_conversation().await.unwrap();
1056
1057        let id1 = store.save_message(cid, "user", "a").await.unwrap();
1058        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
1059        let id3 = store.save_message(cid, "user", "c").await.unwrap();
1060
1061        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
1062        assert_eq!(ids, vec![id1, id2]);
1063        assert!(ids[0] < ids[1]);
1064
1065        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
1066        assert_eq!(all_ids, vec![id1, id2, id3]);
1067    }
1068
1069    #[tokio::test]
1070    async fn message_metadata_default_both_visible() {
1071        let store = test_store().await;
1072        let cid = store.create_conversation().await.unwrap();
1073
1074        store.save_message(cid, "user", "normal").await.unwrap();
1075
1076        let history = store.load_history(cid, 10).await.unwrap();
1077        assert!(history[0].metadata.agent_visible);
1078        assert!(history[0].metadata.user_visible);
1079        assert!(history[0].metadata.compacted_at.is_none());
1080    }
1081}