Skip to main content

zeph_memory/sqlite/
messages.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10fn parse_role(s: &str) -> Role {
11    match s {
12        "assistant" => Role::Assistant,
13        "system" => Role::System,
14        _ => Role::User,
15    }
16}
17
18#[must_use]
19pub fn role_str(role: Role) -> &'static str {
20    match role {
21        Role::System => "system",
22        Role::User => "user",
23        Role::Assistant => "assistant",
24    }
25}
26
27impl SqliteStore {
28    /// Create a new conversation and return its ID.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if the insert fails.
33    pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
34        let row: (ConversationId,) =
35            sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
36                .fetch_one(&self.pool)
37                .await?;
38        Ok(row.0)
39    }
40
41    /// Save a message to the given conversation and return the message ID.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the insert fails.
46    pub async fn save_message(
47        &self,
48        conversation_id: ConversationId,
49        role: &str,
50        content: &str,
51    ) -> Result<MessageId, MemoryError> {
52        self.save_message_with_parts(conversation_id, role, content, "[]")
53            .await
54    }
55
56    /// Save a message with structured parts JSON.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the insert fails.
61    pub async fn save_message_with_parts(
62        &self,
63        conversation_id: ConversationId,
64        role: &str,
65        content: &str,
66        parts_json: &str,
67    ) -> Result<MessageId, MemoryError> {
68        self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
69            .await
70    }
71
72    /// Save a message with visibility metadata.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the insert fails.
77    pub async fn save_message_with_metadata(
78        &self,
79        conversation_id: ConversationId,
80        role: &str,
81        content: &str,
82        parts_json: &str,
83        agent_visible: bool,
84        user_visible: bool,
85    ) -> Result<MessageId, MemoryError> {
86        let row: (MessageId,) = sqlx::query_as(
87            "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
88             VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
89        )
90        .bind(conversation_id)
91        .bind(role)
92        .bind(content)
93        .bind(parts_json)
94        .bind(i64::from(agent_visible))
95        .bind(i64::from(user_visible))
96        .fetch_one(&self.pool)
97        .await?;
98        Ok(row.0)
99    }
100
101    /// Load the most recent messages for a conversation, up to `limit`.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the query fails.
106    pub async fn load_history(
107        &self,
108        conversation_id: ConversationId,
109        limit: u32,
110    ) -> Result<Vec<Message>, MemoryError> {
111        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
112            "SELECT role, content, parts, agent_visible, user_visible FROM (\
113                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
114                WHERE conversation_id = ? \
115                ORDER BY id DESC \
116                LIMIT ?\
117             ) ORDER BY id ASC",
118        )
119        .bind(conversation_id)
120        .bind(limit)
121        .fetch_all(&self.pool)
122        .await?;
123
124        let messages = rows
125            .into_iter()
126            .map(
127                |(role_str, content, parts_json, agent_visible, user_visible)| {
128                    let parts: Vec<MessagePart> = if parts_json == "[]" {
129                        vec![]
130                    } else {
131                        serde_json::from_str(&parts_json).unwrap_or_default()
132                    };
133                    Message {
134                        role: parse_role(&role_str),
135                        content,
136                        parts,
137                        metadata: MessageMetadata {
138                            agent_visible: agent_visible != 0,
139                            user_visible: user_visible != 0,
140                            compacted_at: None,
141                        },
142                    }
143                },
144            )
145            .collect();
146        Ok(messages)
147    }
148
149    /// Load messages filtered by visibility flags.
150    ///
151    /// Pass `Some(true)` to filter by a flag, `None` to skip filtering.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if the query fails.
156    pub async fn load_history_filtered(
157        &self,
158        conversation_id: ConversationId,
159        limit: u32,
160        agent_visible: Option<bool>,
161        user_visible: Option<bool>,
162    ) -> Result<Vec<Message>, MemoryError> {
163        let av = agent_visible.map(i64::from);
164        let uv = user_visible.map(i64::from);
165
166        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
167            "WITH recent AS (\
168                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
169                WHERE conversation_id = ? \
170                  AND (? IS NULL OR agent_visible = ?) \
171                  AND (? IS NULL OR user_visible = ?) \
172                ORDER BY id DESC \
173                LIMIT ?\
174             ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
175        )
176        .bind(conversation_id)
177        .bind(av)
178        .bind(av)
179        .bind(uv)
180        .bind(uv)
181        .bind(limit)
182        .fetch_all(&self.pool)
183        .await?;
184
185        let messages = rows
186            .into_iter()
187            .map(
188                |(role_str, content, parts_json, agent_visible, user_visible)| {
189                    let parts: Vec<MessagePart> = if parts_json == "[]" {
190                        vec![]
191                    } else {
192                        serde_json::from_str(&parts_json).unwrap_or_default()
193                    };
194                    Message {
195                        role: parse_role(&role_str),
196                        content,
197                        parts,
198                        metadata: MessageMetadata {
199                            agent_visible: agent_visible != 0,
200                            user_visible: user_visible != 0,
201                            compacted_at: None,
202                        },
203                    }
204                },
205            )
206            .collect();
207        Ok(messages)
208    }
209
210    /// Atomically mark a range of messages as user-only and insert a summary as agent-only.
211    ///
212    /// Within a single transaction:
213    /// 1. Updates `agent_visible=0, compacted_at=now` for messages in `compacted_range`.
214    /// 2. Inserts `summary_content` with `agent_visible=1, user_visible=0`.
215    ///
216    /// Returns the `MessageId` of the inserted summary.
217    ///
218    /// # Errors
219    ///
220    /// Returns an error if the transaction fails.
221    pub async fn replace_conversation(
222        &self,
223        conversation_id: ConversationId,
224        compacted_range: std::ops::RangeInclusive<MessageId>,
225        summary_role: &str,
226        summary_content: &str,
227    ) -> Result<MessageId, MemoryError> {
228        let now = {
229            let secs = std::time::SystemTime::now()
230                .duration_since(std::time::UNIX_EPOCH)
231                .unwrap_or_default()
232                .as_secs();
233            format!("{secs}")
234        };
235        let start_id = compacted_range.start().0;
236        let end_id = compacted_range.end().0;
237
238        let mut tx = self.pool.begin().await?;
239
240        sqlx::query(
241            "UPDATE messages SET agent_visible = 0, compacted_at = ? \
242             WHERE conversation_id = ? AND id >= ? AND id <= ?",
243        )
244        .bind(&now)
245        .bind(conversation_id)
246        .bind(start_id)
247        .bind(end_id)
248        .execute(&mut *tx)
249        .await?;
250
251        let row: (MessageId,) = sqlx::query_as(
252            "INSERT INTO messages \
253             (conversation_id, role, content, parts, agent_visible, user_visible) \
254             VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
255        )
256        .bind(conversation_id)
257        .bind(summary_role)
258        .bind(summary_content)
259        .fetch_one(&mut *tx)
260        .await?;
261
262        tx.commit().await?;
263
264        Ok(row.0)
265    }
266
267    /// Return the IDs of the N oldest messages in a conversation (ascending order).
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if the query fails.
272    pub async fn oldest_message_ids(
273        &self,
274        conversation_id: ConversationId,
275        n: u32,
276    ) -> Result<Vec<MessageId>, MemoryError> {
277        let rows: Vec<(MessageId,)> = sqlx::query_as(
278            "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
279        )
280        .bind(conversation_id)
281        .bind(n)
282        .fetch_all(&self.pool)
283        .await?;
284        Ok(rows.into_iter().map(|r| r.0).collect())
285    }
286
287    /// Return the ID of the most recent conversation, if any.
288    ///
289    /// # Errors
290    ///
291    /// Returns an error if the query fails.
292    pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
293        let row: Option<(ConversationId,)> =
294            sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
295                .fetch_optional(&self.pool)
296                .await?;
297        Ok(row.map(|r| r.0))
298    }
299
300    /// Fetch a single message by its ID.
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if the query fails.
305    pub async fn message_by_id(
306        &self,
307        message_id: MessageId,
308    ) -> Result<Option<Message>, MemoryError> {
309        let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
310            "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
311        )
312        .bind(message_id)
313        .fetch_optional(&self.pool)
314        .await?;
315
316        Ok(row.map(
317            |(role_str, content, parts_json, agent_visible, user_visible)| {
318                let parts: Vec<MessagePart> = if parts_json == "[]" {
319                    vec![]
320                } else {
321                    serde_json::from_str(&parts_json).unwrap_or_default()
322                };
323                Message {
324                    role: parse_role(&role_str),
325                    content,
326                    parts,
327                    metadata: MessageMetadata {
328                        agent_visible: agent_visible != 0,
329                        user_visible: user_visible != 0,
330                        compacted_at: None,
331                    },
332                }
333            },
334        ))
335    }
336
337    /// Fetch messages by a list of IDs in a single query.
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if the query fails.
342    pub async fn messages_by_ids(
343        &self,
344        ids: &[MessageId],
345    ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
346        if ids.is_empty() {
347            return Ok(Vec::new());
348        }
349
350        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
351
352        let query = format!(
353            "SELECT id, role, content, parts FROM messages \
354             WHERE id IN ({placeholders}) AND agent_visible = 1"
355        );
356        let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
357        for &id in ids {
358            q = q.bind(id);
359        }
360
361        let rows = q.fetch_all(&self.pool).await?;
362
363        Ok(rows
364            .into_iter()
365            .map(|(id, role_str, content, parts_json)| {
366                let parts: Vec<MessagePart> = if parts_json == "[]" {
367                    vec![]
368                } else {
369                    serde_json::from_str(&parts_json).unwrap_or_default()
370                };
371                (
372                    id,
373                    Message {
374                        role: parse_role(&role_str),
375                        content,
376                        parts,
377                        metadata: MessageMetadata::default(),
378                    },
379                )
380            })
381            .collect())
382    }
383
384    /// Return message IDs and content for messages without embeddings.
385    ///
386    /// # Errors
387    ///
388    /// Returns an error if the query fails.
389    pub async fn unembedded_message_ids(
390        &self,
391        limit: Option<usize>,
392    ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
393        let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
394
395        let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
396            "SELECT m.id, m.conversation_id, m.role, m.content \
397             FROM messages m \
398             LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
399             WHERE em.id IS NULL \
400             ORDER BY m.id ASC \
401             LIMIT ?",
402        )
403        .bind(effective_limit)
404        .fetch_all(&self.pool)
405        .await?;
406
407        Ok(rows)
408    }
409
410    /// Count the number of messages in a conversation.
411    ///
412    /// # Errors
413    ///
414    /// Returns an error if the query fails.
415    pub async fn count_messages(
416        &self,
417        conversation_id: ConversationId,
418    ) -> Result<i64, MemoryError> {
419        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
420            .bind(conversation_id)
421            .fetch_one(&self.pool)
422            .await?;
423        Ok(row.0)
424    }
425
426    /// Count messages in a conversation with id greater than `after_id`.
427    ///
428    /// # Errors
429    ///
430    /// Returns an error if the query fails.
431    pub async fn count_messages_after(
432        &self,
433        conversation_id: ConversationId,
434        after_id: MessageId,
435    ) -> Result<i64, MemoryError> {
436        let row: (i64,) =
437            sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
438                .bind(conversation_id)
439                .bind(after_id)
440                .fetch_one(&self.pool)
441                .await?;
442        Ok(row.0)
443    }
444
445    /// Full-text keyword search over messages using FTS5.
446    ///
447    /// Returns message IDs with BM25 relevance scores (lower = more relevant,
448    /// negated to positive for consistency with vector scores).
449    ///
450    /// # Errors
451    ///
452    /// Returns an error if the query fails.
453    pub async fn keyword_search(
454        &self,
455        query: &str,
456        limit: usize,
457        conversation_id: Option<ConversationId>,
458    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
459        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
460
461        let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
462            sqlx::query_as(
463                "SELECT m.id, -rank AS score \
464                 FROM messages_fts f \
465                 JOIN messages m ON m.id = f.rowid \
466                 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
467                 ORDER BY rank \
468                 LIMIT ?",
469            )
470            .bind(query)
471            .bind(cid)
472            .bind(effective_limit)
473            .fetch_all(&self.pool)
474            .await?
475        } else {
476            sqlx::query_as(
477                "SELECT m.id, -rank AS score \
478                 FROM messages_fts f \
479                 JOIN messages m ON m.id = f.rowid \
480                 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
481                 ORDER BY rank \
482                 LIMIT ?",
483            )
484            .bind(query)
485            .bind(effective_limit)
486            .fetch_all(&self.pool)
487            .await?
488        };
489
490        Ok(rows)
491    }
492
493    /// Fetch creation timestamps (Unix epoch seconds) for the given message IDs.
494    ///
495    /// Messages without a `created_at` column fall back to 0.
496    ///
497    /// # Errors
498    ///
499    /// Returns an error if the query fails.
500    pub async fn message_timestamps(
501        &self,
502        ids: &[MessageId],
503    ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
504        if ids.is_empty() {
505            return Ok(std::collections::HashMap::new());
506        }
507
508        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
509        let query = format!(
510            "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
511             FROM messages WHERE id IN ({placeholders})"
512        );
513        let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
514        for &id in ids {
515            q = q.bind(id);
516        }
517
518        let rows = q.fetch_all(&self.pool).await?;
519        Ok(rows.into_iter().collect())
520    }
521
522    /// Load a range of messages after a given message ID.
523    ///
524    /// # Errors
525    ///
526    /// Returns an error if the query fails.
527    pub async fn load_messages_range(
528        &self,
529        conversation_id: ConversationId,
530        after_message_id: MessageId,
531        limit: usize,
532    ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
533        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
534
535        let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
536            "SELECT id, role, content FROM messages \
537             WHERE conversation_id = ? AND id > ? \
538             ORDER BY id ASC LIMIT ?",
539        )
540        .bind(conversation_id)
541        .bind(after_message_id)
542        .bind(effective_limit)
543        .fetch_all(&self.pool)
544        .await?;
545
546        Ok(rows)
547    }
548}
549
550#[cfg(test)]
551mod tests {
552    use super::*;
553
554    async fn test_store() -> SqliteStore {
555        SqliteStore::new(":memory:").await.unwrap()
556    }
557
558    #[tokio::test]
559    async fn create_conversation_returns_id() {
560        let store = test_store().await;
561        let id1 = store.create_conversation().await.unwrap();
562        let id2 = store.create_conversation().await.unwrap();
563        assert_eq!(id1, ConversationId(1));
564        assert_eq!(id2, ConversationId(2));
565    }
566
567    #[tokio::test]
568    async fn save_and_load_messages() {
569        let store = test_store().await;
570        let cid = store.create_conversation().await.unwrap();
571
572        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
573        let msg_id2 = store
574            .save_message(cid, "assistant", "hi there")
575            .await
576            .unwrap();
577
578        assert_eq!(msg_id1, MessageId(1));
579        assert_eq!(msg_id2, MessageId(2));
580
581        let history = store.load_history(cid, 50).await.unwrap();
582        assert_eq!(history.len(), 2);
583        assert_eq!(history[0].role, Role::User);
584        assert_eq!(history[0].content, "hello");
585        assert_eq!(history[1].role, Role::Assistant);
586        assert_eq!(history[1].content, "hi there");
587    }
588
589    #[tokio::test]
590    async fn load_history_respects_limit() {
591        let store = test_store().await;
592        let cid = store.create_conversation().await.unwrap();
593
594        for i in 0..10 {
595            store
596                .save_message(cid, "user", &format!("msg {i}"))
597                .await
598                .unwrap();
599        }
600
601        let history = store.load_history(cid, 3).await.unwrap();
602        assert_eq!(history.len(), 3);
603        assert_eq!(history[0].content, "msg 7");
604        assert_eq!(history[1].content, "msg 8");
605        assert_eq!(history[2].content, "msg 9");
606    }
607
608    #[tokio::test]
609    async fn latest_conversation_id_empty() {
610        let store = test_store().await;
611        assert!(store.latest_conversation_id().await.unwrap().is_none());
612    }
613
614    #[tokio::test]
615    async fn latest_conversation_id_returns_newest() {
616        let store = test_store().await;
617        store.create_conversation().await.unwrap();
618        let id2 = store.create_conversation().await.unwrap();
619        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
620    }
621
622    #[tokio::test]
623    async fn messages_isolated_per_conversation() {
624        let store = test_store().await;
625        let cid1 = store.create_conversation().await.unwrap();
626        let cid2 = store.create_conversation().await.unwrap();
627
628        store.save_message(cid1, "user", "conv1").await.unwrap();
629        store.save_message(cid2, "user", "conv2").await.unwrap();
630
631        let h1 = store.load_history(cid1, 50).await.unwrap();
632        let h2 = store.load_history(cid2, 50).await.unwrap();
633        assert_eq!(h1.len(), 1);
634        assert_eq!(h1[0].content, "conv1");
635        assert_eq!(h2.len(), 1);
636        assert_eq!(h2[0].content, "conv2");
637    }
638
639    #[tokio::test]
640    async fn pool_accessor_returns_valid_pool() {
641        let store = test_store().await;
642        let pool = store.pool();
643        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
644        assert_eq!(row.0, 1);
645    }
646
647    #[tokio::test]
648    async fn embeddings_metadata_table_exists() {
649        let store = test_store().await;
650        let result: (i64,) = sqlx::query_as(
651            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
652        )
653        .fetch_one(store.pool())
654        .await
655        .unwrap();
656        assert_eq!(result.0, 1);
657    }
658
659    #[tokio::test]
660    async fn cascade_delete_removes_embeddings_metadata() {
661        let store = test_store().await;
662        let pool = store.pool();
663
664        let cid = store.create_conversation().await.unwrap();
665        let msg_id = store.save_message(cid, "user", "test").await.unwrap();
666
667        let point_id = uuid::Uuid::new_v4().to_string();
668        sqlx::query(
669            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
670             VALUES (?, ?, ?)",
671        )
672        .bind(msg_id)
673        .bind(&point_id)
674        .bind(768_i64)
675        .execute(pool)
676        .await
677        .unwrap();
678
679        let before: (i64,) =
680            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
681                .bind(msg_id)
682                .fetch_one(pool)
683                .await
684                .unwrap();
685        assert_eq!(before.0, 1);
686
687        sqlx::query("DELETE FROM messages WHERE id = ?")
688            .bind(msg_id)
689            .execute(pool)
690            .await
691            .unwrap();
692
693        let after: (i64,) =
694            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
695                .bind(msg_id)
696                .fetch_one(pool)
697                .await
698                .unwrap();
699        assert_eq!(after.0, 0);
700    }
701
702    #[tokio::test]
703    async fn messages_by_ids_batch_fetch() {
704        let store = test_store().await;
705        let cid = store.create_conversation().await.unwrap();
706        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
707        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
708        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
709
710        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
711        assert_eq!(results.len(), 2);
712        assert_eq!(results[0].0, id1);
713        assert_eq!(results[0].1.content, "hello");
714        assert_eq!(results[1].0, id2);
715        assert_eq!(results[1].1.content, "hi");
716    }
717
718    #[tokio::test]
719    async fn messages_by_ids_empty_input() {
720        let store = test_store().await;
721        let results = store.messages_by_ids(&[]).await.unwrap();
722        assert!(results.is_empty());
723    }
724
725    #[tokio::test]
726    async fn messages_by_ids_nonexistent() {
727        let store = test_store().await;
728        let results = store
729            .messages_by_ids(&[MessageId(999), MessageId(1000)])
730            .await
731            .unwrap();
732        assert!(results.is_empty());
733    }
734
735    #[tokio::test]
736    async fn message_by_id_fetches_existing() {
737        let store = test_store().await;
738        let cid = store.create_conversation().await.unwrap();
739        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
740
741        let msg = store.message_by_id(msg_id).await.unwrap();
742        assert!(msg.is_some());
743        let msg = msg.unwrap();
744        assert_eq!(msg.role, Role::User);
745        assert_eq!(msg.content, "hello");
746    }
747
748    #[tokio::test]
749    async fn message_by_id_returns_none_for_nonexistent() {
750        let store = test_store().await;
751        let msg = store.message_by_id(MessageId(999)).await.unwrap();
752        assert!(msg.is_none());
753    }
754
755    #[tokio::test]
756    async fn unembedded_message_ids_returns_all_when_none_embedded() {
757        let store = test_store().await;
758        let cid = store.create_conversation().await.unwrap();
759
760        store.save_message(cid, "user", "msg1").await.unwrap();
761        store.save_message(cid, "assistant", "msg2").await.unwrap();
762
763        let unembedded = store.unembedded_message_ids(None).await.unwrap();
764        assert_eq!(unembedded.len(), 2);
765        assert_eq!(unembedded[0].3, "msg1");
766        assert_eq!(unembedded[1].3, "msg2");
767    }
768
769    #[tokio::test]
770    async fn unembedded_message_ids_excludes_embedded() {
771        let store = test_store().await;
772        let pool = store.pool();
773        let cid = store.create_conversation().await.unwrap();
774
775        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
776        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
777
778        let point_id = uuid::Uuid::new_v4().to_string();
779        sqlx::query(
780            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
781             VALUES (?, ?, ?)",
782        )
783        .bind(msg_id1)
784        .bind(&point_id)
785        .bind(768_i64)
786        .execute(pool)
787        .await
788        .unwrap();
789
790        let unembedded = store.unembedded_message_ids(None).await.unwrap();
791        assert_eq!(unembedded.len(), 1);
792        assert_eq!(unembedded[0].0, msg_id2);
793        assert_eq!(unembedded[0].3, "msg2");
794    }
795
796    #[tokio::test]
797    async fn unembedded_message_ids_respects_limit() {
798        let store = test_store().await;
799        let cid = store.create_conversation().await.unwrap();
800
801        for i in 0..10 {
802            store
803                .save_message(cid, "user", &format!("msg{i}"))
804                .await
805                .unwrap();
806        }
807
808        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
809        assert_eq!(unembedded.len(), 3);
810    }
811
812    #[tokio::test]
813    async fn count_messages_returns_correct_count() {
814        let store = test_store().await;
815        let cid = store.create_conversation().await.unwrap();
816
817        assert_eq!(store.count_messages(cid).await.unwrap(), 0);
818
819        store.save_message(cid, "user", "msg1").await.unwrap();
820        store.save_message(cid, "assistant", "msg2").await.unwrap();
821
822        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
823    }
824
825    #[tokio::test]
826    async fn count_messages_after_filters_correctly() {
827        let store = test_store().await;
828        let cid = store.create_conversation().await.unwrap();
829
830        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
831        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
832        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
833
834        assert_eq!(
835            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
836            3
837        );
838        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
839        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
840    }
841
842    #[tokio::test]
843    async fn load_messages_range_basic() {
844        let store = test_store().await;
845        let cid = store.create_conversation().await.unwrap();
846
847        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
848        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
849        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
850
851        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
852        assert_eq!(msgs.len(), 2);
853        assert_eq!(msgs[0].0, msg_id2);
854        assert_eq!(msgs[0].2, "msg2");
855        assert_eq!(msgs[1].0, msg_id3);
856        assert_eq!(msgs[1].2, "msg3");
857    }
858
859    #[tokio::test]
860    async fn load_messages_range_respects_limit() {
861        let store = test_store().await;
862        let cid = store.create_conversation().await.unwrap();
863
864        store.save_message(cid, "user", "msg1").await.unwrap();
865        store.save_message(cid, "assistant", "msg2").await.unwrap();
866        store.save_message(cid, "user", "msg3").await.unwrap();
867
868        let msgs = store
869            .load_messages_range(cid, MessageId(0), 2)
870            .await
871            .unwrap();
872        assert_eq!(msgs.len(), 2);
873    }
874
875    #[tokio::test]
876    async fn keyword_search_basic() {
877        let store = test_store().await;
878        let cid = store.create_conversation().await.unwrap();
879
880        store
881            .save_message(cid, "user", "rust programming language")
882            .await
883            .unwrap();
884        store
885            .save_message(cid, "assistant", "python is great too")
886            .await
887            .unwrap();
888        store
889            .save_message(cid, "user", "I love rust and cargo")
890            .await
891            .unwrap();
892
893        let results = store.keyword_search("rust", 10, None).await.unwrap();
894        assert_eq!(results.len(), 2);
895        assert!(results.iter().all(|(_, score)| *score > 0.0));
896    }
897
898    #[tokio::test]
899    async fn keyword_search_with_conversation_filter() {
900        let store = test_store().await;
901        let cid1 = store.create_conversation().await.unwrap();
902        let cid2 = store.create_conversation().await.unwrap();
903
904        store
905            .save_message(cid1, "user", "hello world")
906            .await
907            .unwrap();
908        store
909            .save_message(cid2, "user", "hello universe")
910            .await
911            .unwrap();
912
913        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
914        assert_eq!(results.len(), 1);
915    }
916
917    #[tokio::test]
918    async fn keyword_search_no_match() {
919        let store = test_store().await;
920        let cid = store.create_conversation().await.unwrap();
921
922        store
923            .save_message(cid, "user", "hello world")
924            .await
925            .unwrap();
926
927        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
928        assert!(results.is_empty());
929    }
930
931    #[tokio::test]
932    async fn keyword_search_respects_limit() {
933        let store = test_store().await;
934        let cid = store.create_conversation().await.unwrap();
935
936        for i in 0..10 {
937            store
938                .save_message(cid, "user", &format!("test message {i}"))
939                .await
940                .unwrap();
941        }
942
943        let results = store.keyword_search("test", 3, None).await.unwrap();
944        assert_eq!(results.len(), 3);
945    }
946
947    #[tokio::test]
948    async fn save_message_with_metadata_stores_visibility() {
949        let store = test_store().await;
950        let cid = store.create_conversation().await.unwrap();
951
952        let id = store
953            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
954            .await
955            .unwrap();
956
957        let history = store.load_history(cid, 10).await.unwrap();
958        assert_eq!(history.len(), 1);
959        assert!(!history[0].metadata.agent_visible);
960        assert!(history[0].metadata.user_visible);
961        assert_eq!(id, MessageId(1));
962    }
963
964    #[tokio::test]
965    async fn load_history_filtered_by_agent_visible() {
966        let store = test_store().await;
967        let cid = store.create_conversation().await.unwrap();
968
969        store
970            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
971            .await
972            .unwrap();
973        store
974            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
975            .await
976            .unwrap();
977
978        let agent_msgs = store
979            .load_history_filtered(cid, 50, Some(true), None)
980            .await
981            .unwrap();
982        assert_eq!(agent_msgs.len(), 1);
983        assert_eq!(agent_msgs[0].content, "visible to agent");
984    }
985
986    #[tokio::test]
987    async fn load_history_filtered_by_user_visible() {
988        let store = test_store().await;
989        let cid = store.create_conversation().await.unwrap();
990
991        store
992            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
993            .await
994            .unwrap();
995        store
996            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
997            .await
998            .unwrap();
999
1000        let user_msgs = store
1001            .load_history_filtered(cid, 50, None, Some(true))
1002            .await
1003            .unwrap();
1004        assert_eq!(user_msgs.len(), 1);
1005        assert_eq!(user_msgs[0].content, "user sees this");
1006    }
1007
1008    #[tokio::test]
1009    async fn load_history_filtered_no_filter_returns_all() {
1010        let store = test_store().await;
1011        let cid = store.create_conversation().await.unwrap();
1012
1013        store
1014            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
1015            .await
1016            .unwrap();
1017        store
1018            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
1019            .await
1020            .unwrap();
1021
1022        let all_msgs = store
1023            .load_history_filtered(cid, 50, None, None)
1024            .await
1025            .unwrap();
1026        assert_eq!(all_msgs.len(), 2);
1027    }
1028
1029    #[tokio::test]
1030    async fn replace_conversation_marks_originals_and_inserts_summary() {
1031        let store = test_store().await;
1032        let cid = store.create_conversation().await.unwrap();
1033
1034        let id1 = store.save_message(cid, "user", "first").await.unwrap();
1035        let id2 = store
1036            .save_message(cid, "assistant", "second")
1037            .await
1038            .unwrap();
1039        let id3 = store.save_message(cid, "user", "third").await.unwrap();
1040
1041        let summary_id = store
1042            .replace_conversation(cid, id1..=id2, "system", "summary text")
1043            .await
1044            .unwrap();
1045
1046        // Original messages should be user_only
1047        let all = store.load_history(cid, 50).await.unwrap();
1048        // id1 and id2 marked agent_visible=false, id3 untouched, summary inserted
1049        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
1050        assert!(!by_id1.metadata.agent_visible);
1051        assert!(by_id1.metadata.user_visible);
1052
1053        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
1054        assert!(!by_id2.metadata.agent_visible);
1055
1056        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
1057        assert!(by_id3.metadata.agent_visible);
1058
1059        // Summary is agent_only (agent_visible=1, user_visible=0)
1060        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
1061        assert!(summary.metadata.agent_visible);
1062        assert!(!summary.metadata.user_visible);
1063        assert!(summary_id > id3);
1064    }
1065
1066    #[tokio::test]
1067    async fn oldest_message_ids_returns_in_order() {
1068        let store = test_store().await;
1069        let cid = store.create_conversation().await.unwrap();
1070
1071        let id1 = store.save_message(cid, "user", "a").await.unwrap();
1072        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
1073        let id3 = store.save_message(cid, "user", "c").await.unwrap();
1074
1075        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
1076        assert_eq!(ids, vec![id1, id2]);
1077        assert!(ids[0] < ids[1]);
1078
1079        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
1080        assert_eq!(all_ids, vec![id1, id2, id3]);
1081    }
1082
1083    #[tokio::test]
1084    async fn message_metadata_default_both_visible() {
1085        let store = test_store().await;
1086        let cid = store.create_conversation().await.unwrap();
1087
1088        store.save_message(cid, "user", "normal").await.unwrap();
1089
1090        let history = store.load_history(cid, 10).await.unwrap();
1091        assert!(history[0].metadata.agent_visible);
1092        assert!(history[0].metadata.user_visible);
1093        assert!(history[0].metadata.compacted_at.is_none());
1094    }
1095
1096    #[tokio::test]
1097    async fn load_history_empty_parts_json_fast_path() {
1098        let store = test_store().await;
1099        let cid = store.create_conversation().await.unwrap();
1100
1101        store
1102            .save_message_with_parts(cid, "user", "hello", "[]")
1103            .await
1104            .unwrap();
1105
1106        let history = store.load_history(cid, 10).await.unwrap();
1107        assert_eq!(history.len(), 1);
1108        assert!(
1109            history[0].parts.is_empty(),
1110            "\"[]\" fast-path must yield empty parts Vec"
1111        );
1112    }
1113
1114    #[tokio::test]
1115    async fn load_history_non_empty_parts_json_parsed() {
1116        let store = test_store().await;
1117        let cid = store.create_conversation().await.unwrap();
1118
1119        let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
1120            tool_use_id: "t1".into(),
1121            content: "result".into(),
1122            is_error: false,
1123        }])
1124        .unwrap();
1125
1126        store
1127            .save_message_with_parts(cid, "user", "hello", &parts_json)
1128            .await
1129            .unwrap();
1130
1131        let history = store.load_history(cid, 10).await.unwrap();
1132        assert_eq!(history.len(), 1);
1133        assert_eq!(history[0].parts.len(), 1);
1134        assert!(
1135            matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
1136        );
1137    }
1138
1139    #[tokio::test]
1140    async fn message_by_id_empty_parts_json_fast_path() {
1141        let store = test_store().await;
1142        let cid = store.create_conversation().await.unwrap();
1143
1144        let id = store
1145            .save_message_with_parts(cid, "user", "msg", "[]")
1146            .await
1147            .unwrap();
1148
1149        let msg = store.message_by_id(id).await.unwrap().unwrap();
1150        assert!(
1151            msg.parts.is_empty(),
1152            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
1153        );
1154    }
1155
1156    #[tokio::test]
1157    async fn messages_by_ids_empty_parts_json_fast_path() {
1158        let store = test_store().await;
1159        let cid = store.create_conversation().await.unwrap();
1160
1161        let id = store
1162            .save_message_with_parts(cid, "user", "msg", "[]")
1163            .await
1164            .unwrap();
1165
1166        let results = store.messages_by_ids(&[id]).await.unwrap();
1167        assert_eq!(results.len(), 1);
1168        assert!(
1169            results[0].1.parts.is_empty(),
1170            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
1171        );
1172    }
1173
1174    #[tokio::test]
1175    async fn load_history_filtered_empty_parts_json_fast_path() {
1176        let store = test_store().await;
1177        let cid = store.create_conversation().await.unwrap();
1178
1179        store
1180            .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
1181            .await
1182            .unwrap();
1183
1184        let msgs = store
1185            .load_history_filtered(cid, 10, Some(true), None)
1186            .await
1187            .unwrap();
1188        assert_eq!(msgs.len(), 1);
1189        assert!(
1190            msgs[0].parts.is_empty(),
1191            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
1192        );
1193    }
1194}