Skip to main content

zeph_memory/sqlite/
messages.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10/// Sanitize an arbitrary string into a valid FTS5 query.
11///
12/// Extracts alphanumeric tokens and joins them with spaces. FTS5 special
13/// characters (`,`, `"`, `*`, `(`, `)`, `^`, `-`, `+`) are stripped to avoid
14/// syntax errors in the `MATCH` clause.
/// Sanitize an arbitrary string into a valid FTS5 query.
///
/// The input is split on every non-alphanumeric character (Unicode-aware, via
/// [`char::is_alphanumeric`]). Empty pieces are dropped and the remaining tokens
/// are joined with single spaces, which FTS5 treats as an implicit `AND`. This
/// strips FTS5 syntax characters such as `,`, `"`, `*`, `(`, `)`, `^`, `:`, `-`
/// and `+`, which would otherwise cause syntax errors in the `MATCH` clause.
///
/// Tokens that spell an FTS5 operator keyword (`AND`, `OR`, `NOT`; `NEAR` is
/// included defensively) are wrapped in double quotes so FTS5 matches them as
/// plain terms. Without this, inputs such as `NOT working` or `this OR` would
/// produce a query that FTS5 rejects with a syntax error, or one it silently
/// reinterprets as a boolean expression. Keyword detection is case-sensitive,
/// like FTS5 itself: lowercase `and`/`or`/`not` are ordinary terms.
///
/// Returns an empty string when the input has no alphanumeric characters;
/// callers must not pass that to `MATCH`.
fn sanitize_fts5_query(query: &str) -> String {
    // Barewords with special meaning in FTS5 query syntax.
    const FTS5_KEYWORDS: [&str; 4] = ["AND", "OR", "NOT", "NEAR"];

    query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| {
            if FTS5_KEYWORDS.contains(&token) {
                // A quoted string is always a term. The token is purely
                // alphanumeric, so it contains no `"` that would need escaping.
                format!("\"{token}\"")
            } else {
                token.to_owned()
            }
        })
        .collect::<Vec<_>>()
        .join(" ")
}
22
23fn parse_role(s: &str) -> Role {
24    match s {
25        "assistant" => Role::Assistant,
26        "system" => Role::System,
27        _ => Role::User,
28    }
29}
30
31#[must_use]
32pub fn role_str(role: Role) -> &'static str {
33    match role {
34        Role::System => "system",
35        Role::User => "user",
36        Role::Assistant => "assistant",
37    }
38}
39
40impl SqliteStore {
41    /// Create a new conversation and return its ID.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the insert fails.
46    pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
47        let row: (ConversationId,) =
48            sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
49                .fetch_one(&self.pool)
50                .await?;
51        Ok(row.0)
52    }
53
54    /// Save a message to the given conversation and return the message ID.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the insert fails.
59    pub async fn save_message(
60        &self,
61        conversation_id: ConversationId,
62        role: &str,
63        content: &str,
64    ) -> Result<MessageId, MemoryError> {
65        self.save_message_with_parts(conversation_id, role, content, "[]")
66            .await
67    }
68
69    /// Save a message with structured parts JSON.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if the insert fails.
74    pub async fn save_message_with_parts(
75        &self,
76        conversation_id: ConversationId,
77        role: &str,
78        content: &str,
79        parts_json: &str,
80    ) -> Result<MessageId, MemoryError> {
81        self.save_message_with_metadata(conversation_id, role, content, parts_json, true, true)
82            .await
83    }
84
85    /// Save a message with visibility metadata.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the insert fails.
90    pub async fn save_message_with_metadata(
91        &self,
92        conversation_id: ConversationId,
93        role: &str,
94        content: &str,
95        parts_json: &str,
96        agent_visible: bool,
97        user_visible: bool,
98    ) -> Result<MessageId, MemoryError> {
99        let row: (MessageId,) = sqlx::query_as(
100            "INSERT INTO messages (conversation_id, role, content, parts, agent_visible, user_visible) \
101             VALUES (?, ?, ?, ?, ?, ?) RETURNING id",
102        )
103        .bind(conversation_id)
104        .bind(role)
105        .bind(content)
106        .bind(parts_json)
107        .bind(i64::from(agent_visible))
108        .bind(i64::from(user_visible))
109        .fetch_one(&self.pool)
110        .await?;
111        Ok(row.0)
112    }
113
114    /// Load the most recent messages for a conversation, up to `limit`.
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the query fails.
119    pub async fn load_history(
120        &self,
121        conversation_id: ConversationId,
122        limit: u32,
123    ) -> Result<Vec<Message>, MemoryError> {
124        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
125            "SELECT role, content, parts, agent_visible, user_visible FROM (\
126                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
127                WHERE conversation_id = ? \
128                ORDER BY id DESC \
129                LIMIT ?\
130             ) ORDER BY id ASC",
131        )
132        .bind(conversation_id)
133        .bind(limit)
134        .fetch_all(&self.pool)
135        .await?;
136
137        let messages = rows
138            .into_iter()
139            .map(
140                |(role_str, content, parts_json, agent_visible, user_visible)| {
141                    let parts: Vec<MessagePart> = if parts_json == "[]" {
142                        vec![]
143                    } else {
144                        serde_json::from_str(&parts_json).unwrap_or_default()
145                    };
146                    Message {
147                        role: parse_role(&role_str),
148                        content,
149                        parts,
150                        metadata: MessageMetadata {
151                            agent_visible: agent_visible != 0,
152                            user_visible: user_visible != 0,
153                            compacted_at: None,
154                        },
155                    }
156                },
157            )
158            .collect();
159        Ok(messages)
160    }
161
162    /// Load messages filtered by visibility flags.
163    ///
164    /// Pass `Some(true)` to filter by a flag, `None` to skip filtering.
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if the query fails.
169    pub async fn load_history_filtered(
170        &self,
171        conversation_id: ConversationId,
172        limit: u32,
173        agent_visible: Option<bool>,
174        user_visible: Option<bool>,
175    ) -> Result<Vec<Message>, MemoryError> {
176        let av = agent_visible.map(i64::from);
177        let uv = user_visible.map(i64::from);
178
179        let rows: Vec<(String, String, String, i64, i64)> = sqlx::query_as(
180            "WITH recent AS (\
181                SELECT role, content, parts, agent_visible, user_visible, id FROM messages \
182                WHERE conversation_id = ? \
183                  AND (? IS NULL OR agent_visible = ?) \
184                  AND (? IS NULL OR user_visible = ?) \
185                ORDER BY id DESC \
186                LIMIT ?\
187             ) SELECT role, content, parts, agent_visible, user_visible FROM recent ORDER BY id ASC",
188        )
189        .bind(conversation_id)
190        .bind(av)
191        .bind(av)
192        .bind(uv)
193        .bind(uv)
194        .bind(limit)
195        .fetch_all(&self.pool)
196        .await?;
197
198        let messages = rows
199            .into_iter()
200            .map(
201                |(role_str, content, parts_json, agent_visible, user_visible)| {
202                    let parts: Vec<MessagePart> = if parts_json == "[]" {
203                        vec![]
204                    } else {
205                        serde_json::from_str(&parts_json).unwrap_or_default()
206                    };
207                    Message {
208                        role: parse_role(&role_str),
209                        content,
210                        parts,
211                        metadata: MessageMetadata {
212                            agent_visible: agent_visible != 0,
213                            user_visible: user_visible != 0,
214                            compacted_at: None,
215                        },
216                    }
217                },
218            )
219            .collect();
220        Ok(messages)
221    }
222
223    /// Atomically mark a range of messages as user-only and insert a summary as agent-only.
224    ///
225    /// Within a single transaction:
226    /// 1. Updates `agent_visible=0, compacted_at=now` for messages in `compacted_range`.
227    /// 2. Inserts `summary_content` with `agent_visible=1, user_visible=0`.
228    ///
229    /// Returns the `MessageId` of the inserted summary.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if the transaction fails.
234    pub async fn replace_conversation(
235        &self,
236        conversation_id: ConversationId,
237        compacted_range: std::ops::RangeInclusive<MessageId>,
238        summary_role: &str,
239        summary_content: &str,
240    ) -> Result<MessageId, MemoryError> {
241        let now = {
242            let secs = std::time::SystemTime::now()
243                .duration_since(std::time::UNIX_EPOCH)
244                .unwrap_or_default()
245                .as_secs();
246            format!("{secs}")
247        };
248        let start_id = compacted_range.start().0;
249        let end_id = compacted_range.end().0;
250
251        let mut tx = self.pool.begin().await?;
252
253        sqlx::query(
254            "UPDATE messages SET agent_visible = 0, compacted_at = ? \
255             WHERE conversation_id = ? AND id >= ? AND id <= ?",
256        )
257        .bind(&now)
258        .bind(conversation_id)
259        .bind(start_id)
260        .bind(end_id)
261        .execute(&mut *tx)
262        .await?;
263
264        let row: (MessageId,) = sqlx::query_as(
265            "INSERT INTO messages \
266             (conversation_id, role, content, parts, agent_visible, user_visible) \
267             VALUES (?, ?, ?, '[]', 1, 0) RETURNING id",
268        )
269        .bind(conversation_id)
270        .bind(summary_role)
271        .bind(summary_content)
272        .fetch_one(&mut *tx)
273        .await?;
274
275        tx.commit().await?;
276
277        Ok(row.0)
278    }
279
280    /// Return the IDs of the N oldest messages in a conversation (ascending order).
281    ///
282    /// # Errors
283    ///
284    /// Returns an error if the query fails.
285    pub async fn oldest_message_ids(
286        &self,
287        conversation_id: ConversationId,
288        n: u32,
289    ) -> Result<Vec<MessageId>, MemoryError> {
290        let rows: Vec<(MessageId,)> = sqlx::query_as(
291            "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id ASC LIMIT ?",
292        )
293        .bind(conversation_id)
294        .bind(n)
295        .fetch_all(&self.pool)
296        .await?;
297        Ok(rows.into_iter().map(|r| r.0).collect())
298    }
299
300    /// Return the ID of the most recent conversation, if any.
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if the query fails.
305    pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
306        let row: Option<(ConversationId,)> =
307            sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
308                .fetch_optional(&self.pool)
309                .await?;
310        Ok(row.map(|r| r.0))
311    }
312
313    /// Fetch a single message by its ID.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the query fails.
318    pub async fn message_by_id(
319        &self,
320        message_id: MessageId,
321    ) -> Result<Option<Message>, MemoryError> {
322        let row: Option<(String, String, String, i64, i64)> = sqlx::query_as(
323            "SELECT role, content, parts, agent_visible, user_visible FROM messages WHERE id = ?",
324        )
325        .bind(message_id)
326        .fetch_optional(&self.pool)
327        .await?;
328
329        Ok(row.map(
330            |(role_str, content, parts_json, agent_visible, user_visible)| {
331                let parts: Vec<MessagePart> = if parts_json == "[]" {
332                    vec![]
333                } else {
334                    serde_json::from_str(&parts_json).unwrap_or_default()
335                };
336                Message {
337                    role: parse_role(&role_str),
338                    content,
339                    parts,
340                    metadata: MessageMetadata {
341                        agent_visible: agent_visible != 0,
342                        user_visible: user_visible != 0,
343                        compacted_at: None,
344                    },
345                }
346            },
347        ))
348    }
349
350    /// Fetch messages by a list of IDs in a single query.
351    ///
352    /// # Errors
353    ///
354    /// Returns an error if the query fails.
355    pub async fn messages_by_ids(
356        &self,
357        ids: &[MessageId],
358    ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
359        if ids.is_empty() {
360            return Ok(Vec::new());
361        }
362
363        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
364
365        let query = format!(
366            "SELECT id, role, content, parts FROM messages \
367             WHERE id IN ({placeholders}) AND agent_visible = 1"
368        );
369        let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
370        for &id in ids {
371            q = q.bind(id);
372        }
373
374        let rows = q.fetch_all(&self.pool).await?;
375
376        Ok(rows
377            .into_iter()
378            .map(|(id, role_str, content, parts_json)| {
379                let parts: Vec<MessagePart> = if parts_json == "[]" {
380                    vec![]
381                } else {
382                    serde_json::from_str(&parts_json).unwrap_or_default()
383                };
384                (
385                    id,
386                    Message {
387                        role: parse_role(&role_str),
388                        content,
389                        parts,
390                        metadata: MessageMetadata::default(),
391                    },
392                )
393            })
394            .collect())
395    }
396
397    /// Return message IDs and content for messages without embeddings.
398    ///
399    /// # Errors
400    ///
401    /// Returns an error if the query fails.
402    pub async fn unembedded_message_ids(
403        &self,
404        limit: Option<usize>,
405    ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
406        let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
407
408        let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
409            "SELECT m.id, m.conversation_id, m.role, m.content \
410             FROM messages m \
411             LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
412             WHERE em.id IS NULL \
413             ORDER BY m.id ASC \
414             LIMIT ?",
415        )
416        .bind(effective_limit)
417        .fetch_all(&self.pool)
418        .await?;
419
420        Ok(rows)
421    }
422
423    /// Count the number of messages in a conversation.
424    ///
425    /// # Errors
426    ///
427    /// Returns an error if the query fails.
428    pub async fn count_messages(
429        &self,
430        conversation_id: ConversationId,
431    ) -> Result<i64, MemoryError> {
432        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
433            .bind(conversation_id)
434            .fetch_one(&self.pool)
435            .await?;
436        Ok(row.0)
437    }
438
439    /// Count messages in a conversation with id greater than `after_id`.
440    ///
441    /// # Errors
442    ///
443    /// Returns an error if the query fails.
444    pub async fn count_messages_after(
445        &self,
446        conversation_id: ConversationId,
447        after_id: MessageId,
448    ) -> Result<i64, MemoryError> {
449        let row: (i64,) =
450            sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
451                .bind(conversation_id)
452                .bind(after_id)
453                .fetch_one(&self.pool)
454                .await?;
455        Ok(row.0)
456    }
457
458    /// Full-text keyword search over messages using FTS5.
459    ///
460    /// Returns message IDs with BM25 relevance scores (lower = more relevant,
461    /// negated to positive for consistency with vector scores).
462    ///
463    /// # Errors
464    ///
465    /// Returns an error if the query fails.
466    pub async fn keyword_search(
467        &self,
468        query: &str,
469        limit: usize,
470        conversation_id: Option<ConversationId>,
471    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
472        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
473        let safe_query = sanitize_fts5_query(query);
474        if safe_query.is_empty() {
475            return Ok(Vec::new());
476        }
477
478        let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
479            sqlx::query_as(
480                "SELECT m.id, -rank AS score \
481                 FROM messages_fts f \
482                 JOIN messages m ON m.id = f.rowid \
483                 WHERE messages_fts MATCH ? AND m.conversation_id = ? AND m.agent_visible = 1 \
484                 ORDER BY rank \
485                 LIMIT ?",
486            )
487            .bind(&safe_query)
488            .bind(cid)
489            .bind(effective_limit)
490            .fetch_all(&self.pool)
491            .await?
492        } else {
493            sqlx::query_as(
494                "SELECT m.id, -rank AS score \
495                 FROM messages_fts f \
496                 JOIN messages m ON m.id = f.rowid \
497                 WHERE messages_fts MATCH ? AND m.agent_visible = 1 \
498                 ORDER BY rank \
499                 LIMIT ?",
500            )
501            .bind(&safe_query)
502            .bind(effective_limit)
503            .fetch_all(&self.pool)
504            .await?
505        };
506
507        Ok(rows)
508    }
509
510    /// Fetch creation timestamps (Unix epoch seconds) for the given message IDs.
511    ///
512    /// Messages without a `created_at` column fall back to 0.
513    ///
514    /// # Errors
515    ///
516    /// Returns an error if the query fails.
517    pub async fn message_timestamps(
518        &self,
519        ids: &[MessageId],
520    ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
521        if ids.is_empty() {
522            return Ok(std::collections::HashMap::new());
523        }
524
525        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
526        let query = format!(
527            "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
528             FROM messages WHERE id IN ({placeholders})"
529        );
530        let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
531        for &id in ids {
532            q = q.bind(id);
533        }
534
535        let rows = q.fetch_all(&self.pool).await?;
536        Ok(rows.into_iter().collect())
537    }
538
539    /// Load a range of messages after a given message ID.
540    ///
541    /// # Errors
542    ///
543    /// Returns an error if the query fails.
544    pub async fn load_messages_range(
545        &self,
546        conversation_id: ConversationId,
547        after_message_id: MessageId,
548        limit: usize,
549    ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
550        let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
551
552        let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
553            "SELECT id, role, content FROM messages \
554             WHERE conversation_id = ? AND id > ? \
555             ORDER BY id ASC LIMIT ?",
556        )
557        .bind(conversation_id)
558        .bind(after_message_id)
559        .bind(effective_limit)
560        .fetch_all(&self.pool)
561        .await?;
562
563        Ok(rows)
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570
571    async fn test_store() -> SqliteStore {
572        SqliteStore::new(":memory:").await.unwrap()
573    }
574
575    #[tokio::test]
576    async fn create_conversation_returns_id() {
577        let store = test_store().await;
578        let id1 = store.create_conversation().await.unwrap();
579        let id2 = store.create_conversation().await.unwrap();
580        assert_eq!(id1, ConversationId(1));
581        assert_eq!(id2, ConversationId(2));
582    }
583
584    #[tokio::test]
585    async fn save_and_load_messages() {
586        let store = test_store().await;
587        let cid = store.create_conversation().await.unwrap();
588
589        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
590        let msg_id2 = store
591            .save_message(cid, "assistant", "hi there")
592            .await
593            .unwrap();
594
595        assert_eq!(msg_id1, MessageId(1));
596        assert_eq!(msg_id2, MessageId(2));
597
598        let history = store.load_history(cid, 50).await.unwrap();
599        assert_eq!(history.len(), 2);
600        assert_eq!(history[0].role, Role::User);
601        assert_eq!(history[0].content, "hello");
602        assert_eq!(history[1].role, Role::Assistant);
603        assert_eq!(history[1].content, "hi there");
604    }
605
606    #[tokio::test]
607    async fn load_history_respects_limit() {
608        let store = test_store().await;
609        let cid = store.create_conversation().await.unwrap();
610
611        for i in 0..10 {
612            store
613                .save_message(cid, "user", &format!("msg {i}"))
614                .await
615                .unwrap();
616        }
617
618        let history = store.load_history(cid, 3).await.unwrap();
619        assert_eq!(history.len(), 3);
620        assert_eq!(history[0].content, "msg 7");
621        assert_eq!(history[1].content, "msg 8");
622        assert_eq!(history[2].content, "msg 9");
623    }
624
625    #[tokio::test]
626    async fn latest_conversation_id_empty() {
627        let store = test_store().await;
628        assert!(store.latest_conversation_id().await.unwrap().is_none());
629    }
630
631    #[tokio::test]
632    async fn latest_conversation_id_returns_newest() {
633        let store = test_store().await;
634        store.create_conversation().await.unwrap();
635        let id2 = store.create_conversation().await.unwrap();
636        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
637    }
638
639    #[tokio::test]
640    async fn messages_isolated_per_conversation() {
641        let store = test_store().await;
642        let cid1 = store.create_conversation().await.unwrap();
643        let cid2 = store.create_conversation().await.unwrap();
644
645        store.save_message(cid1, "user", "conv1").await.unwrap();
646        store.save_message(cid2, "user", "conv2").await.unwrap();
647
648        let h1 = store.load_history(cid1, 50).await.unwrap();
649        let h2 = store.load_history(cid2, 50).await.unwrap();
650        assert_eq!(h1.len(), 1);
651        assert_eq!(h1[0].content, "conv1");
652        assert_eq!(h2.len(), 1);
653        assert_eq!(h2[0].content, "conv2");
654    }
655
656    #[tokio::test]
657    async fn pool_accessor_returns_valid_pool() {
658        let store = test_store().await;
659        let pool = store.pool();
660        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
661        assert_eq!(row.0, 1);
662    }
663
664    #[tokio::test]
665    async fn embeddings_metadata_table_exists() {
666        let store = test_store().await;
667        let result: (i64,) = sqlx::query_as(
668            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
669        )
670        .fetch_one(store.pool())
671        .await
672        .unwrap();
673        assert_eq!(result.0, 1);
674    }
675
676    #[tokio::test]
677    async fn cascade_delete_removes_embeddings_metadata() {
678        let store = test_store().await;
679        let pool = store.pool();
680
681        let cid = store.create_conversation().await.unwrap();
682        let msg_id = store.save_message(cid, "user", "test").await.unwrap();
683
684        let point_id = uuid::Uuid::new_v4().to_string();
685        sqlx::query(
686            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
687             VALUES (?, ?, ?)",
688        )
689        .bind(msg_id)
690        .bind(&point_id)
691        .bind(768_i64)
692        .execute(pool)
693        .await
694        .unwrap();
695
696        let before: (i64,) =
697            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
698                .bind(msg_id)
699                .fetch_one(pool)
700                .await
701                .unwrap();
702        assert_eq!(before.0, 1);
703
704        sqlx::query("DELETE FROM messages WHERE id = ?")
705            .bind(msg_id)
706            .execute(pool)
707            .await
708            .unwrap();
709
710        let after: (i64,) =
711            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
712                .bind(msg_id)
713                .fetch_one(pool)
714                .await
715                .unwrap();
716        assert_eq!(after.0, 0);
717    }
718
719    #[tokio::test]
720    async fn messages_by_ids_batch_fetch() {
721        let store = test_store().await;
722        let cid = store.create_conversation().await.unwrap();
723        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
724        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
725        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();
726
727        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
728        assert_eq!(results.len(), 2);
729        assert_eq!(results[0].0, id1);
730        assert_eq!(results[0].1.content, "hello");
731        assert_eq!(results[1].0, id2);
732        assert_eq!(results[1].1.content, "hi");
733    }
734
735    #[tokio::test]
736    async fn messages_by_ids_empty_input() {
737        let store = test_store().await;
738        let results = store.messages_by_ids(&[]).await.unwrap();
739        assert!(results.is_empty());
740    }
741
742    #[tokio::test]
743    async fn messages_by_ids_nonexistent() {
744        let store = test_store().await;
745        let results = store
746            .messages_by_ids(&[MessageId(999), MessageId(1000)])
747            .await
748            .unwrap();
749        assert!(results.is_empty());
750    }
751
752    #[tokio::test]
753    async fn message_by_id_fetches_existing() {
754        let store = test_store().await;
755        let cid = store.create_conversation().await.unwrap();
756        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();
757
758        let msg = store.message_by_id(msg_id).await.unwrap();
759        assert!(msg.is_some());
760        let msg = msg.unwrap();
761        assert_eq!(msg.role, Role::User);
762        assert_eq!(msg.content, "hello");
763    }
764
765    #[tokio::test]
766    async fn message_by_id_returns_none_for_nonexistent() {
767        let store = test_store().await;
768        let msg = store.message_by_id(MessageId(999)).await.unwrap();
769        assert!(msg.is_none());
770    }
771
772    #[tokio::test]
773    async fn unembedded_message_ids_returns_all_when_none_embedded() {
774        let store = test_store().await;
775        let cid = store.create_conversation().await.unwrap();
776
777        store.save_message(cid, "user", "msg1").await.unwrap();
778        store.save_message(cid, "assistant", "msg2").await.unwrap();
779
780        let unembedded = store.unembedded_message_ids(None).await.unwrap();
781        assert_eq!(unembedded.len(), 2);
782        assert_eq!(unembedded[0].3, "msg1");
783        assert_eq!(unembedded[1].3, "msg2");
784    }
785
786    #[tokio::test]
787    async fn unembedded_message_ids_excludes_embedded() {
788        let store = test_store().await;
789        let pool = store.pool();
790        let cid = store.create_conversation().await.unwrap();
791
792        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
793        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
794
795        let point_id = uuid::Uuid::new_v4().to_string();
796        sqlx::query(
797            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
798             VALUES (?, ?, ?)",
799        )
800        .bind(msg_id1)
801        .bind(&point_id)
802        .bind(768_i64)
803        .execute(pool)
804        .await
805        .unwrap();
806
807        let unembedded = store.unembedded_message_ids(None).await.unwrap();
808        assert_eq!(unembedded.len(), 1);
809        assert_eq!(unembedded[0].0, msg_id2);
810        assert_eq!(unembedded[0].3, "msg2");
811    }
812
813    #[tokio::test]
814    async fn unembedded_message_ids_respects_limit() {
815        let store = test_store().await;
816        let cid = store.create_conversation().await.unwrap();
817
818        for i in 0..10 {
819            store
820                .save_message(cid, "user", &format!("msg{i}"))
821                .await
822                .unwrap();
823        }
824
825        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
826        assert_eq!(unembedded.len(), 3);
827    }
828
829    #[tokio::test]
830    async fn count_messages_returns_correct_count() {
831        let store = test_store().await;
832        let cid = store.create_conversation().await.unwrap();
833
834        assert_eq!(store.count_messages(cid).await.unwrap(), 0);
835
836        store.save_message(cid, "user", "msg1").await.unwrap();
837        store.save_message(cid, "assistant", "msg2").await.unwrap();
838
839        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
840    }
841
842    #[tokio::test]
843    async fn count_messages_after_filters_correctly() {
844        let store = test_store().await;
845        let cid = store.create_conversation().await.unwrap();
846
847        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
848        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
849        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();
850
851        assert_eq!(
852            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
853            3
854        );
855        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
856        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
857    }
858
859    #[tokio::test]
860    async fn load_messages_range_basic() {
861        let store = test_store().await;
862        let cid = store.create_conversation().await.unwrap();
863
864        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
865        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
866        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();
867
868        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
869        assert_eq!(msgs.len(), 2);
870        assert_eq!(msgs[0].0, msg_id2);
871        assert_eq!(msgs[0].2, "msg2");
872        assert_eq!(msgs[1].0, msg_id3);
873        assert_eq!(msgs[1].2, "msg3");
874    }
875
876    #[tokio::test]
877    async fn load_messages_range_respects_limit() {
878        let store = test_store().await;
879        let cid = store.create_conversation().await.unwrap();
880
881        store.save_message(cid, "user", "msg1").await.unwrap();
882        store.save_message(cid, "assistant", "msg2").await.unwrap();
883        store.save_message(cid, "user", "msg3").await.unwrap();
884
885        let msgs = store
886            .load_messages_range(cid, MessageId(0), 2)
887            .await
888            .unwrap();
889        assert_eq!(msgs.len(), 2);
890    }
891
892    #[tokio::test]
893    async fn keyword_search_basic() {
894        let store = test_store().await;
895        let cid = store.create_conversation().await.unwrap();
896
897        store
898            .save_message(cid, "user", "rust programming language")
899            .await
900            .unwrap();
901        store
902            .save_message(cid, "assistant", "python is great too")
903            .await
904            .unwrap();
905        store
906            .save_message(cid, "user", "I love rust and cargo")
907            .await
908            .unwrap();
909
910        let results = store.keyword_search("rust", 10, None).await.unwrap();
911        assert_eq!(results.len(), 2);
912        assert!(results.iter().all(|(_, score)| *score > 0.0));
913    }
914
915    #[tokio::test]
916    async fn keyword_search_with_conversation_filter() {
917        let store = test_store().await;
918        let cid1 = store.create_conversation().await.unwrap();
919        let cid2 = store.create_conversation().await.unwrap();
920
921        store
922            .save_message(cid1, "user", "hello world")
923            .await
924            .unwrap();
925        store
926            .save_message(cid2, "user", "hello universe")
927            .await
928            .unwrap();
929
930        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
931        assert_eq!(results.len(), 1);
932    }
933
934    #[tokio::test]
935    async fn keyword_search_no_match() {
936        let store = test_store().await;
937        let cid = store.create_conversation().await.unwrap();
938
939        store
940            .save_message(cid, "user", "hello world")
941            .await
942            .unwrap();
943
944        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
945        assert!(results.is_empty());
946    }
947
948    #[tokio::test]
949    async fn keyword_search_respects_limit() {
950        let store = test_store().await;
951        let cid = store.create_conversation().await.unwrap();
952
953        for i in 0..10 {
954            store
955                .save_message(cid, "user", &format!("test message {i}"))
956                .await
957                .unwrap();
958        }
959
960        let results = store.keyword_search("test", 3, None).await.unwrap();
961        assert_eq!(results.len(), 3);
962    }
963
964    #[test]
965    fn sanitize_fts5_query_strips_special_chars() {
966        assert_eq!(sanitize_fts5_query("skill-audit"), "skill audit");
967        assert_eq!(sanitize_fts5_query("hello, world"), "hello world");
968        assert_eq!(sanitize_fts5_query("a+b*c^d"), "a b c d");
969        assert_eq!(sanitize_fts5_query("  "), "");
970        assert_eq!(sanitize_fts5_query("rust programming"), "rust programming");
971    }
972
973    #[tokio::test]
974    async fn keyword_search_with_special_chars_does_not_error() {
975        let store = test_store().await;
976        let cid = store.create_conversation().await.unwrap();
977        store
978            .save_message(cid, "user", "skill audit info")
979            .await
980            .unwrap();
981        // query with comma and special chars — previously caused FTS5 syntax error
982        // result may be empty; important is that no error is returned
983        store
984            .keyword_search("skill-audit, confidence=0.1", 10, None)
985            .await
986            .unwrap();
987    }
988
989    #[tokio::test]
990    async fn save_message_with_metadata_stores_visibility() {
991        let store = test_store().await;
992        let cid = store.create_conversation().await.unwrap();
993
994        let id = store
995            .save_message_with_metadata(cid, "user", "hello", "[]", false, true)
996            .await
997            .unwrap();
998
999        let history = store.load_history(cid, 10).await.unwrap();
1000        assert_eq!(history.len(), 1);
1001        assert!(!history[0].metadata.agent_visible);
1002        assert!(history[0].metadata.user_visible);
1003        assert_eq!(id, MessageId(1));
1004    }
1005
1006    #[tokio::test]
1007    async fn load_history_filtered_by_agent_visible() {
1008        let store = test_store().await;
1009        let cid = store.create_conversation().await.unwrap();
1010
1011        store
1012            .save_message_with_metadata(cid, "user", "visible to agent", "[]", true, true)
1013            .await
1014            .unwrap();
1015        store
1016            .save_message_with_metadata(cid, "user", "user only", "[]", false, true)
1017            .await
1018            .unwrap();
1019
1020        let agent_msgs = store
1021            .load_history_filtered(cid, 50, Some(true), None)
1022            .await
1023            .unwrap();
1024        assert_eq!(agent_msgs.len(), 1);
1025        assert_eq!(agent_msgs[0].content, "visible to agent");
1026    }
1027
1028    #[tokio::test]
1029    async fn load_history_filtered_by_user_visible() {
1030        let store = test_store().await;
1031        let cid = store.create_conversation().await.unwrap();
1032
1033        store
1034            .save_message_with_metadata(cid, "system", "agent only summary", "[]", true, false)
1035            .await
1036            .unwrap();
1037        store
1038            .save_message_with_metadata(cid, "user", "user sees this", "[]", true, true)
1039            .await
1040            .unwrap();
1041
1042        let user_msgs = store
1043            .load_history_filtered(cid, 50, None, Some(true))
1044            .await
1045            .unwrap();
1046        assert_eq!(user_msgs.len(), 1);
1047        assert_eq!(user_msgs[0].content, "user sees this");
1048    }
1049
1050    #[tokio::test]
1051    async fn load_history_filtered_no_filter_returns_all() {
1052        let store = test_store().await;
1053        let cid = store.create_conversation().await.unwrap();
1054
1055        store
1056            .save_message_with_metadata(cid, "user", "msg1", "[]", true, false)
1057            .await
1058            .unwrap();
1059        store
1060            .save_message_with_metadata(cid, "user", "msg2", "[]", false, true)
1061            .await
1062            .unwrap();
1063
1064        let all_msgs = store
1065            .load_history_filtered(cid, 50, None, None)
1066            .await
1067            .unwrap();
1068        assert_eq!(all_msgs.len(), 2);
1069    }
1070
1071    #[tokio::test]
1072    async fn replace_conversation_marks_originals_and_inserts_summary() {
1073        let store = test_store().await;
1074        let cid = store.create_conversation().await.unwrap();
1075
1076        let id1 = store.save_message(cid, "user", "first").await.unwrap();
1077        let id2 = store
1078            .save_message(cid, "assistant", "second")
1079            .await
1080            .unwrap();
1081        let id3 = store.save_message(cid, "user", "third").await.unwrap();
1082
1083        let summary_id = store
1084            .replace_conversation(cid, id1..=id2, "system", "summary text")
1085            .await
1086            .unwrap();
1087
1088        // Original messages should be user_only
1089        let all = store.load_history(cid, 50).await.unwrap();
1090        // id1 and id2 marked agent_visible=false, id3 untouched, summary inserted
1091        let by_id1 = all.iter().find(|m| m.content == "first").unwrap();
1092        assert!(!by_id1.metadata.agent_visible);
1093        assert!(by_id1.metadata.user_visible);
1094
1095        let by_id2 = all.iter().find(|m| m.content == "second").unwrap();
1096        assert!(!by_id2.metadata.agent_visible);
1097
1098        let by_id3 = all.iter().find(|m| m.content == "third").unwrap();
1099        assert!(by_id3.metadata.agent_visible);
1100
1101        // Summary is agent_only (agent_visible=1, user_visible=0)
1102        let summary = all.iter().find(|m| m.content == "summary text").unwrap();
1103        assert!(summary.metadata.agent_visible);
1104        assert!(!summary.metadata.user_visible);
1105        assert!(summary_id > id3);
1106    }
1107
1108    #[tokio::test]
1109    async fn oldest_message_ids_returns_in_order() {
1110        let store = test_store().await;
1111        let cid = store.create_conversation().await.unwrap();
1112
1113        let id1 = store.save_message(cid, "user", "a").await.unwrap();
1114        let id2 = store.save_message(cid, "assistant", "b").await.unwrap();
1115        let id3 = store.save_message(cid, "user", "c").await.unwrap();
1116
1117        let ids = store.oldest_message_ids(cid, 2).await.unwrap();
1118        assert_eq!(ids, vec![id1, id2]);
1119        assert!(ids[0] < ids[1]);
1120
1121        let all_ids = store.oldest_message_ids(cid, 10).await.unwrap();
1122        assert_eq!(all_ids, vec![id1, id2, id3]);
1123    }
1124
1125    #[tokio::test]
1126    async fn message_metadata_default_both_visible() {
1127        let store = test_store().await;
1128        let cid = store.create_conversation().await.unwrap();
1129
1130        store.save_message(cid, "user", "normal").await.unwrap();
1131
1132        let history = store.load_history(cid, 10).await.unwrap();
1133        assert!(history[0].metadata.agent_visible);
1134        assert!(history[0].metadata.user_visible);
1135        assert!(history[0].metadata.compacted_at.is_none());
1136    }
1137
1138    #[tokio::test]
1139    async fn load_history_empty_parts_json_fast_path() {
1140        let store = test_store().await;
1141        let cid = store.create_conversation().await.unwrap();
1142
1143        store
1144            .save_message_with_parts(cid, "user", "hello", "[]")
1145            .await
1146            .unwrap();
1147
1148        let history = store.load_history(cid, 10).await.unwrap();
1149        assert_eq!(history.len(), 1);
1150        assert!(
1151            history[0].parts.is_empty(),
1152            "\"[]\" fast-path must yield empty parts Vec"
1153        );
1154    }
1155
1156    #[tokio::test]
1157    async fn load_history_non_empty_parts_json_parsed() {
1158        let store = test_store().await;
1159        let cid = store.create_conversation().await.unwrap();
1160
1161        let parts_json = serde_json::to_string(&vec![MessagePart::ToolResult {
1162            tool_use_id: "t1".into(),
1163            content: "result".into(),
1164            is_error: false,
1165        }])
1166        .unwrap();
1167
1168        store
1169            .save_message_with_parts(cid, "user", "hello", &parts_json)
1170            .await
1171            .unwrap();
1172
1173        let history = store.load_history(cid, 10).await.unwrap();
1174        assert_eq!(history.len(), 1);
1175        assert_eq!(history[0].parts.len(), 1);
1176        assert!(
1177            matches!(&history[0].parts[0], MessagePart::ToolResult { content, .. } if content == "result")
1178        );
1179    }
1180
1181    #[tokio::test]
1182    async fn message_by_id_empty_parts_json_fast_path() {
1183        let store = test_store().await;
1184        let cid = store.create_conversation().await.unwrap();
1185
1186        let id = store
1187            .save_message_with_parts(cid, "user", "msg", "[]")
1188            .await
1189            .unwrap();
1190
1191        let msg = store.message_by_id(id).await.unwrap().unwrap();
1192        assert!(
1193            msg.parts.is_empty(),
1194            "\"[]\" fast-path must yield empty parts Vec in message_by_id"
1195        );
1196    }
1197
1198    #[tokio::test]
1199    async fn messages_by_ids_empty_parts_json_fast_path() {
1200        let store = test_store().await;
1201        let cid = store.create_conversation().await.unwrap();
1202
1203        let id = store
1204            .save_message_with_parts(cid, "user", "msg", "[]")
1205            .await
1206            .unwrap();
1207
1208        let results = store.messages_by_ids(&[id]).await.unwrap();
1209        assert_eq!(results.len(), 1);
1210        assert!(
1211            results[0].1.parts.is_empty(),
1212            "\"[]\" fast-path must yield empty parts Vec in messages_by_ids"
1213        );
1214    }
1215
1216    #[tokio::test]
1217    async fn load_history_filtered_empty_parts_json_fast_path() {
1218        let store = test_store().await;
1219        let cid = store.create_conversation().await.unwrap();
1220
1221        store
1222            .save_message_with_metadata(cid, "user", "msg", "[]", true, true)
1223            .await
1224            .unwrap();
1225
1226        let msgs = store
1227            .load_history_filtered(cid, 10, Some(true), None)
1228            .await
1229            .unwrap();
1230        assert_eq!(msgs.len(), 1);
1231        assert!(
1232            msgs[0].parts.is_empty(),
1233            "\"[]\" fast-path must yield empty parts Vec in load_history_filtered"
1234        );
1235    }
1236}