Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Tells the embedding store whether a vector belongs to an ordinary message or
/// to a summary; the distinction is persisted as the boolean `is_summary`
/// payload field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// A regular conversation message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` otherwise.
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
28
/// Name of the vector collection that holds conversation-message embeddings.
///
/// Every `EmbeddingStore` constructor targets this collection for its `store*` and
/// `search` operations.
const COLLECTION_NAME: &str = "zeph_conversations";
30
31/// Ensure a Qdrant collection exists with cosine distance vectors.
32///
33/// Idempotent: no-op if the collection already exists.
34///
35/// # Errors
36///
37/// Returns an error if Qdrant cannot be reached or collection creation fails.
38pub async fn ensure_qdrant_collection(
39    ops: &QdrantOps,
40    collection: &str,
41    vector_size: u64,
42) -> Result<(), Box<qdrant_client::QdrantError>> {
43    ops.ensure_collection(collection, vector_size).await
44}
45
46pub struct EmbeddingStore {
47    ops: Box<dyn VectorStore>,
48    collection: String,
49    pool: DbPool,
50}
51
52impl std::fmt::Debug for EmbeddingStore {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("EmbeddingStore")
55            .field("collection", &self.collection)
56            .finish_non_exhaustive()
57    }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62    pub conversation_id: Option<ConversationId>,
63    pub role: Option<String>,
64    /// Optional category filter for category-aware memory (#2428).
65    /// When `Some`, Qdrant search is restricted to vectors with a matching `category` payload.
66    pub category: Option<String>,
67}
68
69#[derive(Debug)]
70pub struct SearchResult {
71    pub message_id: MessageId,
72    pub conversation_id: ConversationId,
73    pub score: f32,
74}
75
76impl EmbeddingStore {
77    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
78    ///
79    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
80    /// table (which must already exist via sqlx migrations).
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if the Qdrant client cannot be created.
85    pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
86        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
87
88        Ok(Self {
89            ops: Box::new(ops),
90            collection: COLLECTION_NAME.into(),
91            pool,
92        })
93    }
94
95    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
96    ///
97    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
98    #[must_use]
99    pub fn new_sqlite(pool: DbPool) -> Self {
100        let ops = DbVectorStore::new(pool.clone());
101        Self {
102            ops: Box::new(ops),
103            collection: COLLECTION_NAME.into(),
104            pool,
105        }
106    }
107
108    #[must_use]
109    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
110        Self {
111            ops: store,
112            collection: COLLECTION_NAME.into(),
113            pool,
114        }
115    }
116
117    pub async fn health_check(&self) -> bool {
118        self.ops.health_check().await.unwrap_or(false)
119    }
120
121    /// Ensure the collection exists in Qdrant with the given vector size.
122    ///
123    /// Idempotent: no-op if the collection already exists.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if Qdrant cannot be reached or collection creation fails.
128    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
129        self.ops
130            .ensure_collection(&self.collection, vector_size)
131            .await?;
132        // Create keyword indexes for the fields used in filtered recall so Qdrant can satisfy
133        // filter conditions in O(log n) instead of scanning all payload documents.
134        self.ops
135            .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
136            .await?;
137        Ok(())
138    }
139
140    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
141    ///
142    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
143    /// alongside the standard fields. This allows filtering and scoring by tool context
144    /// without corrupting the embedding vector with text prefixes.
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
149    #[allow(clippy::too_many_arguments)]
150    pub async fn store_with_tool_context(
151        &self,
152        message_id: MessageId,
153        conversation_id: ConversationId,
154        role: &str,
155        vector: Vec<f32>,
156        kind: MessageKind,
157        model: &str,
158        chunk_index: u32,
159        tool_name: &str,
160        exit_code: Option<i32>,
161        timestamp: Option<&str>,
162    ) -> Result<String, MemoryError> {
163        let point_id = uuid::Uuid::new_v4().to_string();
164        let dimensions = i64::try_from(vector.len())?;
165
166        let mut payload = std::collections::HashMap::from([
167            ("message_id".to_owned(), serde_json::json!(message_id.0)),
168            (
169                "conversation_id".to_owned(),
170                serde_json::json!(conversation_id.0),
171            ),
172            ("role".to_owned(), serde_json::json!(role)),
173            (
174                "is_summary".to_owned(),
175                serde_json::json!(kind.is_summary()),
176            ),
177            ("tool_name".to_owned(), serde_json::json!(tool_name)),
178        ]);
179        if let Some(code) = exit_code {
180            payload.insert("exit_code".to_owned(), serde_json::json!(code));
181        }
182        if let Some(ts) = timestamp {
183            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
184        }
185
186        let point = VectorPoint {
187            id: point_id.clone(),
188            vector,
189            payload,
190        };
191
192        self.ops.upsert(&self.collection, vec![point]).await?;
193
194        let chunk_index_i64 = i64::from(chunk_index);
195        zeph_db::query(sql!(
196            "INSERT INTO embeddings_metadata \
197             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
198             VALUES (?, ?, ?, ?, ?) \
199             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
200             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
201        ))
202        .bind(message_id)
203        .bind(chunk_index_i64)
204        .bind(&point_id)
205        .bind(dimensions)
206        .bind(model)
207        .execute(&self.pool)
208        .await?;
209
210        Ok(point_id)
211    }
212
213    /// Store a vector in Qdrant and persist metadata to `SQLite`.
214    ///
215    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
216    /// when a long message is split into multiple embeddings.
217    ///
218    /// Returns the UUID of the newly created Qdrant point.
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
223    #[allow(clippy::too_many_arguments)]
224    pub async fn store(
225        &self,
226        message_id: MessageId,
227        conversation_id: ConversationId,
228        role: &str,
229        vector: Vec<f32>,
230        kind: MessageKind,
231        model: &str,
232        chunk_index: u32,
233    ) -> Result<String, MemoryError> {
234        let point_id = uuid::Uuid::new_v4().to_string();
235        let dimensions = i64::try_from(vector.len())?;
236
237        let payload = std::collections::HashMap::from([
238            ("message_id".to_owned(), serde_json::json!(message_id.0)),
239            (
240                "conversation_id".to_owned(),
241                serde_json::json!(conversation_id.0),
242            ),
243            ("role".to_owned(), serde_json::json!(role)),
244            (
245                "is_summary".to_owned(),
246                serde_json::json!(kind.is_summary()),
247            ),
248        ]);
249
250        let point = VectorPoint {
251            id: point_id.clone(),
252            vector,
253            payload,
254        };
255
256        self.ops.upsert(&self.collection, vec![point]).await?;
257
258        let chunk_index_i64 = i64::from(chunk_index);
259        zeph_db::query(sql!(
260            "INSERT INTO embeddings_metadata \
261             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
262             VALUES (?, ?, ?, ?, ?) \
263             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
264             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
265        ))
266        .bind(message_id)
267        .bind(chunk_index_i64)
268        .bind(&point_id)
269        .bind(dimensions)
270        .bind(model)
271        .execute(&self.pool)
272        .await?;
273
274        Ok(point_id)
275    }
276
277    /// Store a vector with an optional category tag in the Qdrant payload.
278    ///
279    /// Identical to [`store`] but adds a `category` field to the payload when provided.
280    /// Used by category-aware memory (#2428) to enable category-filtered recall.
281    ///
282    /// Note: when `category` is `None` no `category` field is written to the Qdrant payload.
283    /// Memories stored before category-aware recall was enabled therefore won't match a
284    /// category filter — this is intentional (no silent false-positives), but a backfill
285    /// pass is needed if retrospective categorization is desired.
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
290    #[allow(clippy::too_many_arguments)]
291    pub async fn store_with_category(
292        &self,
293        message_id: MessageId,
294        conversation_id: ConversationId,
295        role: &str,
296        vector: Vec<f32>,
297        kind: MessageKind,
298        model: &str,
299        chunk_index: u32,
300        category: Option<&str>,
301    ) -> Result<String, MemoryError> {
302        let point_id = uuid::Uuid::new_v4().to_string();
303        let dimensions = i64::try_from(vector.len())?;
304
305        let mut payload = std::collections::HashMap::from([
306            ("message_id".to_owned(), serde_json::json!(message_id.0)),
307            (
308                "conversation_id".to_owned(),
309                serde_json::json!(conversation_id.0),
310            ),
311            ("role".to_owned(), serde_json::json!(role)),
312            (
313                "is_summary".to_owned(),
314                serde_json::json!(kind.is_summary()),
315            ),
316        ]);
317        if let Some(cat) = category {
318            payload.insert("category".to_owned(), serde_json::json!(cat));
319        }
320
321        let point = VectorPoint {
322            id: point_id.clone(),
323            vector,
324            payload,
325        };
326
327        self.ops.upsert(&self.collection, vec![point]).await?;
328
329        let chunk_index_i64 = i64::from(chunk_index);
330        zeph_db::query(sql!(
331            "INSERT INTO embeddings_metadata \
332             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
333             VALUES (?, ?, ?, ?, ?) \
334             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
335             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
336        ))
337        .bind(message_id)
338        .bind(chunk_index_i64)
339        .bind(&point_id)
340        .bind(dimensions)
341        .bind(model)
342        .execute(&self.pool)
343        .await?;
344
345        Ok(point_id)
346    }
347
348    /// Search for similar vectors in Qdrant, returning up to `limit` results.
349    ///
350    /// # Errors
351    ///
352    /// Returns an error if the Qdrant search fails.
353    pub async fn search(
354        &self,
355        query_vector: &[f32],
356        limit: usize,
357        filter: Option<SearchFilter>,
358    ) -> Result<Vec<SearchResult>, MemoryError> {
359        let limit_u64 = u64::try_from(limit)?;
360
361        let vector_filter = filter.as_ref().and_then(|f| {
362            let mut must = Vec::new();
363            if let Some(cid) = f.conversation_id {
364                must.push(FieldCondition {
365                    field: "conversation_id".into(),
366                    value: FieldValue::Integer(cid.0),
367                });
368            }
369            if let Some(ref role) = f.role {
370                must.push(FieldCondition {
371                    field: "role".into(),
372                    value: FieldValue::Text(role.clone()),
373                });
374            }
375            if let Some(ref category) = f.category {
376                must.push(FieldCondition {
377                    field: "category".into(),
378                    value: FieldValue::Text(category.clone()),
379                });
380            }
381            if must.is_empty() {
382                None
383            } else {
384                Some(VectorFilter {
385                    must,
386                    must_not: vec![],
387                })
388            }
389        });
390
391        let results = self
392            .ops
393            .search(
394                &self.collection,
395                query_vector.to_vec(),
396                limit_u64,
397                vector_filter,
398            )
399            .await?;
400
401        // Deduplicate by message_id, keeping the chunk with the highest score.
402        // A single message may produce multiple Qdrant points (one per chunk).
403        let mut best: std::collections::HashMap<MessageId, SearchResult> =
404            std::collections::HashMap::new();
405        for point in results {
406            let Some(message_id) = point
407                .payload
408                .get("message_id")
409                .and_then(serde_json::Value::as_i64)
410            else {
411                continue;
412            };
413            let Some(conversation_id) = point
414                .payload
415                .get("conversation_id")
416                .and_then(serde_json::Value::as_i64)
417            else {
418                continue;
419            };
420            let message_id = MessageId(message_id);
421            let entry = best.entry(message_id).or_insert(SearchResult {
422                message_id,
423                conversation_id: ConversationId(conversation_id),
424                score: f32::NEG_INFINITY,
425            });
426            if point.score > entry.score {
427                entry.score = point.score;
428            }
429        }
430
431        let mut search_results: Vec<SearchResult> = best.into_values().collect();
432        search_results.sort_by(|a, b| {
433            b.score
434                .partial_cmp(&a.score)
435                .unwrap_or(std::cmp::Ordering::Equal)
436        });
437        search_results.truncate(limit);
438
439        Ok(search_results)
440    }
441
442    /// Check whether a named collection exists in the vector store.
443    ///
444    /// # Errors
445    ///
446    /// Returns an error if the store backend cannot be reached.
447    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
448        self.ops.collection_exists(name).await.map_err(Into::into)
449    }
450
451    /// Ensure a named collection exists in Qdrant with the given vector size.
452    ///
453    /// # Errors
454    ///
455    /// Returns an error if Qdrant cannot be reached or collection creation fails.
456    pub async fn ensure_named_collection(
457        &self,
458        name: &str,
459        vector_size: u64,
460    ) -> Result<(), MemoryError> {
461        self.ops.ensure_collection(name, vector_size).await?;
462        Ok(())
463    }
464
465    /// Store a vector in a named Qdrant collection with arbitrary payload.
466    ///
467    /// Returns the UUID of the newly created point.
468    ///
469    /// # Errors
470    ///
471    /// Returns an error if the Qdrant upsert fails.
472    pub async fn store_to_collection(
473        &self,
474        collection: &str,
475        payload: serde_json::Value,
476        vector: Vec<f32>,
477    ) -> Result<String, MemoryError> {
478        let point_id = uuid::Uuid::new_v4().to_string();
479        let payload_map: std::collections::HashMap<String, serde_json::Value> =
480            serde_json::from_value(payload)?;
481        let point = VectorPoint {
482            id: point_id.clone(),
483            vector,
484            payload: payload_map,
485        };
486        self.ops.upsert(collection, vec![point]).await?;
487        Ok(point_id)
488    }
489
490    /// Upsert a vector into a named collection, reusing an existing point ID.
491    ///
492    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
493    ///
494    /// # Errors
495    ///
496    /// Returns an error if the Qdrant upsert fails.
497    pub async fn upsert_to_collection(
498        &self,
499        collection: &str,
500        point_id: &str,
501        payload: serde_json::Value,
502        vector: Vec<f32>,
503    ) -> Result<(), MemoryError> {
504        let payload_map: std::collections::HashMap<String, serde_json::Value> =
505            serde_json::from_value(payload)?;
506        let point = VectorPoint {
507            id: point_id.to_owned(),
508            vector,
509            payload: payload_map,
510        };
511        self.ops.upsert(collection, vec![point]).await?;
512        Ok(())
513    }
514
515    /// Search a named Qdrant collection, returning scored points with payloads.
516    ///
517    /// # Errors
518    ///
519    /// Returns an error if the Qdrant search fails.
520    pub async fn search_collection(
521        &self,
522        collection: &str,
523        query_vector: &[f32],
524        limit: usize,
525        filter: Option<VectorFilter>,
526    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
527        let limit_u64 = u64::try_from(limit)?;
528        let results = self
529            .ops
530            .search(collection, query_vector.to_vec(), limit_u64, filter)
531            .await?;
532        Ok(results)
533    }
534
535    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
536    ///
537    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
538    ///
539    /// # Errors
540    ///
541    /// Returns an error if the `SQLite` query fails.
542    pub async fn get_vectors(
543        &self,
544        ids: &[MessageId],
545    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
546        if ids.is_empty() {
547            return Ok(std::collections::HashMap::new());
548        }
549
550        let placeholders = zeph_db::placeholder_list(1, ids.len());
551        let query = format!(
552            "SELECT em.message_id, vp.vector \
553             FROM embeddings_metadata em \
554             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
555             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
556        );
557        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
558        for &id in ids {
559            q = q.bind(id);
560        }
561
562        let rows = q.fetch_all(&self.pool).await?;
563
564        let map = rows
565            .into_iter()
566            .filter_map(|(msg_id, blob)| {
567                if blob.len() % 4 != 0 {
568                    return None;
569                }
570                let vec: Vec<f32> = blob
571                    .chunks_exact(4)
572                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
573                    .collect();
574                Some((msg_id, vec))
575            })
576            .collect();
577
578        Ok(map)
579    }
580
581    /// Check whether an embedding already exists for the given message ID.
582    ///
583    /// # Errors
584    ///
585    /// Returns an error if the `SQLite` query fails.
586    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
587        let row: (i64,) = zeph_db::query_as(sql!(
588            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
589        ))
590        .bind(message_id)
591        .fetch_one(&self.pool)
592        .await?;
593
594        Ok(row.0 > 0)
595    }
596
597    /// Check whether a Qdrant embedding for `entity_name` is current by comparing the
598    /// Qdrant-side epoch against the epoch stored in `graph_entities`.
599    ///
600    /// Returns `true` if the Qdrant embedding is up-to-date or if the entity no longer
601    /// exists in `SQLite` (embedding should be cleaned up separately).
602    ///
603    /// # Errors
604    ///
605    /// Returns an error if the `SQLite` query fails.
606    pub async fn is_epoch_current(
607        &self,
608        entity_name: &str,
609        qdrant_epoch: u64,
610    ) -> Result<bool, MemoryError> {
611        let row: Option<(i64,)> = zeph_db::query_as(sql!(
612            "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
613        ))
614        .bind(entity_name)
615        .fetch_optional(&self.pool)
616        .await?;
617
618        match row {
619            None => Ok(true), // entity deleted; Qdrant point is orphaned, not stale per epoch
620            Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
621        }
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        // Create collection first
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (_store, pool) = setup().await;

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(999_i64)
        .fetch_one(&pool)
        .await
        .unwrap();

        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                    category: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn embedding_store_search_keeps_best_chunk_per_message() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "long message")
            .await
            .unwrap();

        // The weaker chunk is stored first so the test proves the max score wins,
        // not the first-seen one.
        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                1,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_search_with_category_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let tagged = sqlite.save_message(cid, "user", "tagged").await.unwrap();
        let untagged = sqlite.save_message(cid, "user", "untagged").await.unwrap();

        store
            .store_with_category(
                tagged,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
                Some("work"),
            )
            .await
            .unwrap();
        store
            .store_with_category(
                untagged,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
                None,
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: None,
                    role: None,
                    category: Some("work".into()),
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, tagged);
    }

    #[tokio::test]
    async fn embedding_store_search_zero_limit_returns_empty() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 0, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn store_with_tool_context_records_metadata() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "assistant", "tool output")
            .await
            .unwrap();

        store
            .store_with_tool_context(
                msg_id,
                cid,
                "assistant",
                vec![0.0, 0.0, 1.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
                "shell",
                Some(0),
                Some("2026-01-01T00:00:00Z"),
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn storing_same_chunk_twice_keeps_single_metadata_row() {
        let (store, sqlite) = setup_with_store().await;
        let pool = sqlite.pool().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "again").await.unwrap();

        for _ in 0..2 {
            store
                .store(
                    msg_id,
                    cid,
                    "user",
                    vec![1.0, 0.0, 0.0, 0.0],
                    MessageKind::Regular,
                    "m",
                    0,
                )
                .await
                .unwrap();
        }

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // Same (message_id, chunk_index, model) — must fail.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;
        assert!(result.is_err());

        // Different chunk_index — must succeed.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(1_i64)
        .bind(&point_id3)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();
    }
}
857        .execute(&pool)
858        .await
859        .unwrap();
860    }
861}