Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Qdrant-backed embedding store for message vector search.
5//!
6//! [`EmbeddingStore`] owns a [`VectorStore`] implementation (Qdrant in production,
7//! [`crate::db_vector_store::DbVectorStore`] in tests) and exposes typed `embed` /
8//! `search` / `delete` operations used by [`crate::semantic::SemanticMemory`].
9//!
10//! Message vectors are stored in the `zeph_conversations` Qdrant collection with a
11//! payload that includes `message_id`, `conversation_id`, `role`, and `category`.
12
13pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Category of a stored message embedding: an ordinary conversation turn or a summary.
///
/// The category is written to the vector payload (as the boolean `is_summary` field) so
/// that search filters can restrict results to one category.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A compression summary produced by the summarization subsystem.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
42
43const COLLECTION_NAME: &str = "zeph_conversations";
44
45/// Ensure a Qdrant collection exists with cosine distance vectors.
46///
47/// Idempotent: no-op if the collection already exists.
48///
49/// # Errors
50///
51/// Returns an error if Qdrant cannot be reached or collection creation fails.
52pub async fn ensure_qdrant_collection(
53    ops: &QdrantOps,
54    collection: &str,
55    vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57    ops.ensure_collection(collection, vector_size).await
58}
59
60/// Typed wrapper over a [`VectorStore`] backend for conversation message embeddings.
61///
62/// Constructed via [`EmbeddingStore::new`] (Qdrant URL) or
63/// [`EmbeddingStore::with_store`] (custom backend for testing).
64pub struct EmbeddingStore {
65    ops: Box<dyn VectorStore>,
66    collection: String,
67    pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("EmbeddingStore")
73            .field("collection", &self.collection)
74            .finish_non_exhaustive()
75    }
76}
77
78/// Optional filters applied to a vector similarity search.
79#[derive(Debug)]
80pub struct SearchFilter {
81    /// Restrict results to a single conversation. `None` searches across all conversations.
82    pub conversation_id: Option<ConversationId>,
83    /// Restrict by message role (`"user"` / `"assistant"`). `None` returns all roles.
84    pub role: Option<String>,
85    /// Restrict by category payload field (category-aware memory, #2428).
86    /// When `Some`, Qdrant search is restricted to vectors with a matching `category` payload.
87    pub category: Option<String>,
88}
89
90/// A single result returned by [`EmbeddingStore::search`].
91#[derive(Debug)]
92pub struct SearchResult {
93    /// Database row ID of the matching message.
94    pub message_id: MessageId,
95    /// Conversation the message belongs to.
96    pub conversation_id: ConversationId,
97    /// Cosine similarity score in `[0, 1]`.
98    pub score: f32,
99}
100
101impl EmbeddingStore {
102    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
103    ///
104    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
105    /// table (which must already exist via sqlx migrations).
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the Qdrant client cannot be created.
110    pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
111        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
112
113        Ok(Self {
114            ops: Box::new(ops),
115            collection: COLLECTION_NAME.into(),
116            pool,
117        })
118    }
119
120    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
121    ///
122    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
123    #[must_use]
124    pub fn new_sqlite(pool: DbPool) -> Self {
125        let ops = DbVectorStore::new(pool.clone());
126        Self {
127            ops: Box::new(ops),
128            collection: COLLECTION_NAME.into(),
129            pool,
130        }
131    }
132
133    #[must_use]
134    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
135        Self {
136            ops: store,
137            collection: COLLECTION_NAME.into(),
138            pool,
139        }
140    }
141
142    pub async fn health_check(&self) -> bool {
143        self.ops.health_check().await.unwrap_or(false)
144    }
145
146    /// Ensure the collection exists in Qdrant with the given vector size.
147    ///
148    /// Idempotent: no-op if the collection already exists.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if Qdrant cannot be reached or collection creation fails.
153    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
154        self.ops
155            .ensure_collection(&self.collection, vector_size)
156            .await?;
157        // Create keyword indexes for the fields used in filtered recall so Qdrant can satisfy
158        // filter conditions in O(log n) instead of scanning all payload documents.
159        self.ops
160            .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
161            .await?;
162        Ok(())
163    }
164
165    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
166    ///
167    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
168    /// alongside the standard fields. This allows filtering and scoring by tool context
169    /// without corrupting the embedding vector with text prefixes.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
174    #[allow(clippy::too_many_arguments)]
175    pub async fn store_with_tool_context(
176        &self,
177        message_id: MessageId,
178        conversation_id: ConversationId,
179        role: &str,
180        vector: Vec<f32>,
181        kind: MessageKind,
182        model: &str,
183        chunk_index: u32,
184        tool_name: &str,
185        exit_code: Option<i32>,
186        timestamp: Option<&str>,
187    ) -> Result<String, MemoryError> {
188        let point_id = uuid::Uuid::new_v4().to_string();
189        let dimensions = i64::try_from(vector.len())?;
190
191        let mut payload = std::collections::HashMap::from([
192            ("message_id".to_owned(), serde_json::json!(message_id.0)),
193            (
194                "conversation_id".to_owned(),
195                serde_json::json!(conversation_id.0),
196            ),
197            ("role".to_owned(), serde_json::json!(role)),
198            (
199                "is_summary".to_owned(),
200                serde_json::json!(kind.is_summary()),
201            ),
202            ("tool_name".to_owned(), serde_json::json!(tool_name)),
203        ]);
204        if let Some(code) = exit_code {
205            payload.insert("exit_code".to_owned(), serde_json::json!(code));
206        }
207        if let Some(ts) = timestamp {
208            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
209        }
210
211        let point = VectorPoint {
212            id: point_id.clone(),
213            vector,
214            payload,
215        };
216
217        self.ops.upsert(&self.collection, vec![point]).await?;
218
219        let chunk_index_i64 = i64::from(chunk_index);
220        zeph_db::query(sql!(
221            "INSERT INTO embeddings_metadata \
222             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
223             VALUES (?, ?, ?, ?, ?) \
224             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
225             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
226        ))
227        .bind(message_id)
228        .bind(chunk_index_i64)
229        .bind(&point_id)
230        .bind(dimensions)
231        .bind(model)
232        .execute(&self.pool)
233        .await?;
234
235        Ok(point_id)
236    }
237
238    /// Store a vector in Qdrant and persist metadata to `SQLite`.
239    ///
240    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
241    /// when a long message is split into multiple embeddings.
242    ///
243    /// Returns the UUID of the newly created Qdrant point.
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
248    #[allow(clippy::too_many_arguments)]
249    pub async fn store(
250        &self,
251        message_id: MessageId,
252        conversation_id: ConversationId,
253        role: &str,
254        vector: Vec<f32>,
255        kind: MessageKind,
256        model: &str,
257        chunk_index: u32,
258    ) -> Result<String, MemoryError> {
259        let point_id = uuid::Uuid::new_v4().to_string();
260        let dimensions = i64::try_from(vector.len())?;
261
262        let payload = std::collections::HashMap::from([
263            ("message_id".to_owned(), serde_json::json!(message_id.0)),
264            (
265                "conversation_id".to_owned(),
266                serde_json::json!(conversation_id.0),
267            ),
268            ("role".to_owned(), serde_json::json!(role)),
269            (
270                "is_summary".to_owned(),
271                serde_json::json!(kind.is_summary()),
272            ),
273        ]);
274
275        let point = VectorPoint {
276            id: point_id.clone(),
277            vector,
278            payload,
279        };
280
281        self.ops.upsert(&self.collection, vec![point]).await?;
282
283        let chunk_index_i64 = i64::from(chunk_index);
284        zeph_db::query(sql!(
285            "INSERT INTO embeddings_metadata \
286             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
287             VALUES (?, ?, ?, ?, ?) \
288             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
289             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
290        ))
291        .bind(message_id)
292        .bind(chunk_index_i64)
293        .bind(&point_id)
294        .bind(dimensions)
295        .bind(model)
296        .execute(&self.pool)
297        .await?;
298
299        Ok(point_id)
300    }
301
302    /// Store a vector with an optional category tag in the Qdrant payload.
303    ///
304    /// Identical to [`Self::store`] but adds a `category` field to the payload when provided.
305    /// Used by category-aware memory (#2428) to enable category-filtered recall.
306    ///
307    /// Note: when `category` is `None` no `category` field is written to the Qdrant payload.
308    /// Memories stored before category-aware recall was enabled therefore won't match a
309    /// category filter — this is intentional (no silent false-positives), but a backfill
310    /// pass is needed if retrospective categorization is desired.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
315    #[allow(clippy::too_many_arguments)]
316    pub async fn store_with_category(
317        &self,
318        message_id: MessageId,
319        conversation_id: ConversationId,
320        role: &str,
321        vector: Vec<f32>,
322        kind: MessageKind,
323        model: &str,
324        chunk_index: u32,
325        category: Option<&str>,
326    ) -> Result<String, MemoryError> {
327        let point_id = uuid::Uuid::new_v4().to_string();
328        let dimensions = i64::try_from(vector.len())?;
329
330        let mut payload = std::collections::HashMap::from([
331            ("message_id".to_owned(), serde_json::json!(message_id.0)),
332            (
333                "conversation_id".to_owned(),
334                serde_json::json!(conversation_id.0),
335            ),
336            ("role".to_owned(), serde_json::json!(role)),
337            (
338                "is_summary".to_owned(),
339                serde_json::json!(kind.is_summary()),
340            ),
341        ]);
342        if let Some(cat) = category {
343            payload.insert("category".to_owned(), serde_json::json!(cat));
344        }
345
346        let point = VectorPoint {
347            id: point_id.clone(),
348            vector,
349            payload,
350        };
351
352        self.ops.upsert(&self.collection, vec![point]).await?;
353
354        let chunk_index_i64 = i64::from(chunk_index);
355        zeph_db::query(sql!(
356            "INSERT INTO embeddings_metadata \
357             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
358             VALUES (?, ?, ?, ?, ?) \
359             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
360             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
361        ))
362        .bind(message_id)
363        .bind(chunk_index_i64)
364        .bind(&point_id)
365        .bind(dimensions)
366        .bind(model)
367        .execute(&self.pool)
368        .await?;
369
370        Ok(point_id)
371    }
372
373    /// Search for similar vectors in Qdrant, returning up to `limit` results.
374    ///
375    /// # Errors
376    ///
377    /// Returns an error if the Qdrant search fails.
378    pub async fn search(
379        &self,
380        query_vector: &[f32],
381        limit: usize,
382        filter: Option<SearchFilter>,
383    ) -> Result<Vec<SearchResult>, MemoryError> {
384        let limit_u64 = u64::try_from(limit)?;
385
386        let vector_filter = filter.as_ref().and_then(|f| {
387            let mut must = Vec::new();
388            if let Some(cid) = f.conversation_id {
389                must.push(FieldCondition {
390                    field: "conversation_id".into(),
391                    value: FieldValue::Integer(cid.0),
392                });
393            }
394            if let Some(ref role) = f.role {
395                must.push(FieldCondition {
396                    field: "role".into(),
397                    value: FieldValue::Text(role.clone()),
398                });
399            }
400            if let Some(ref category) = f.category {
401                must.push(FieldCondition {
402                    field: "category".into(),
403                    value: FieldValue::Text(category.clone()),
404                });
405            }
406            if must.is_empty() {
407                None
408            } else {
409                Some(VectorFilter {
410                    must,
411                    must_not: vec![],
412                })
413            }
414        });
415
416        let results = self
417            .ops
418            .search(
419                &self.collection,
420                query_vector.to_vec(),
421                limit_u64,
422                vector_filter,
423            )
424            .await?;
425
426        // Deduplicate by message_id, keeping the chunk with the highest score.
427        // A single message may produce multiple Qdrant points (one per chunk).
428        let mut best: std::collections::HashMap<MessageId, SearchResult> =
429            std::collections::HashMap::new();
430        for point in results {
431            let Some(message_id) = point
432                .payload
433                .get("message_id")
434                .and_then(serde_json::Value::as_i64)
435            else {
436                continue;
437            };
438            let Some(conversation_id) = point
439                .payload
440                .get("conversation_id")
441                .and_then(serde_json::Value::as_i64)
442            else {
443                continue;
444            };
445            let message_id = MessageId(message_id);
446            let entry = best.entry(message_id).or_insert(SearchResult {
447                message_id,
448                conversation_id: ConversationId(conversation_id),
449                score: f32::NEG_INFINITY,
450            });
451            if point.score > entry.score {
452                entry.score = point.score;
453            }
454        }
455
456        let mut search_results: Vec<SearchResult> = best.into_values().collect();
457        search_results.sort_by(|a, b| {
458            b.score
459                .partial_cmp(&a.score)
460                .unwrap_or(std::cmp::Ordering::Equal)
461        });
462        search_results.truncate(limit);
463
464        Ok(search_results)
465    }
466
467    /// Check whether a named collection exists in the vector store.
468    ///
469    /// # Errors
470    ///
471    /// Returns an error if the store backend cannot be reached.
472    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
473        self.ops.collection_exists(name).await.map_err(Into::into)
474    }
475
476    /// Ensure a named collection exists in Qdrant with the given vector size.
477    ///
478    /// # Errors
479    ///
480    /// Returns an error if Qdrant cannot be reached or collection creation fails.
481    pub async fn ensure_named_collection(
482        &self,
483        name: &str,
484        vector_size: u64,
485    ) -> Result<(), MemoryError> {
486        self.ops.ensure_collection(name, vector_size).await?;
487        Ok(())
488    }
489
490    /// Store a vector in a named Qdrant collection with arbitrary payload.
491    ///
492    /// Returns the UUID of the newly created point.
493    ///
494    /// # Errors
495    ///
496    /// Returns an error if the Qdrant upsert fails.
497    pub async fn store_to_collection(
498        &self,
499        collection: &str,
500        payload: serde_json::Value,
501        vector: Vec<f32>,
502    ) -> Result<String, MemoryError> {
503        let point_id = uuid::Uuid::new_v4().to_string();
504        let payload_map: std::collections::HashMap<String, serde_json::Value> =
505            serde_json::from_value(payload)?;
506        let point = VectorPoint {
507            id: point_id.clone(),
508            vector,
509            payload: payload_map,
510        };
511        self.ops.upsert(collection, vec![point]).await?;
512        Ok(point_id)
513    }
514
515    /// Upsert a vector into a named collection, reusing an existing point ID.
516    ///
517    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if the Qdrant upsert fails.
522    pub async fn upsert_to_collection(
523        &self,
524        collection: &str,
525        point_id: &str,
526        payload: serde_json::Value,
527        vector: Vec<f32>,
528    ) -> Result<(), MemoryError> {
529        let payload_map: std::collections::HashMap<String, serde_json::Value> =
530            serde_json::from_value(payload)?;
531        let point = VectorPoint {
532            id: point_id.to_owned(),
533            vector,
534            payload: payload_map,
535        };
536        self.ops.upsert(collection, vec![point]).await?;
537        Ok(())
538    }
539
540    /// Search a named Qdrant collection, returning scored points with payloads.
541    ///
542    /// # Errors
543    ///
544    /// Returns an error if the Qdrant search fails.
545    pub async fn search_collection(
546        &self,
547        collection: &str,
548        query_vector: &[f32],
549        limit: usize,
550        filter: Option<VectorFilter>,
551    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
552        let limit_u64 = u64::try_from(limit)?;
553        let results = self
554            .ops
555            .search(collection, query_vector.to_vec(), limit_u64, filter)
556            .await?;
557        Ok(results)
558    }
559
560    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
561    ///
562    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
563    ///
564    /// # Errors
565    ///
566    /// Returns an error if the `SQLite` query fails.
567    pub async fn get_vectors(
568        &self,
569        ids: &[MessageId],
570    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
571        if ids.is_empty() {
572            return Ok(std::collections::HashMap::new());
573        }
574
575        let placeholders = zeph_db::placeholder_list(1, ids.len());
576        let query = format!(
577            "SELECT em.message_id, vp.vector \
578             FROM embeddings_metadata em \
579             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
580             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
581        );
582        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
583        for &id in ids {
584            q = q.bind(id);
585        }
586
587        let rows = q.fetch_all(&self.pool).await?;
588
589        let map = rows
590            .into_iter()
591            .filter_map(|(msg_id, blob)| {
592                if blob.len() % 4 != 0 {
593                    return None;
594                }
595                let vec: Vec<f32> = blob
596                    .chunks_exact(4)
597                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
598                    .collect();
599                Some((msg_id, vec))
600            })
601            .collect();
602
603        Ok(map)
604    }
605
606    /// Check whether an embedding already exists for the given message ID.
607    ///
608    /// # Errors
609    ///
610    /// Returns an error if the `SQLite` query fails.
611    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
612        let row: (i64,) = zeph_db::query_as(sql!(
613            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
614        ))
615        .bind(message_id)
616        .fetch_one(&self.pool)
617        .await?;
618
619        Ok(row.0 > 0)
620    }
621
622    /// Check whether a Qdrant embedding for `entity_name` is current by comparing the
623    /// Qdrant-side epoch against the epoch stored in `graph_entities`.
624    ///
625    /// Returns `true` if the Qdrant embedding is up-to-date or if the entity no longer
626    /// exists in `SQLite` (embedding should be cleaned up separately).
627    ///
628    /// # Errors
629    ///
630    /// Returns an error if the `SQLite` query fails.
631    pub async fn is_epoch_current(
632        &self,
633        entity_name: &str,
634        qdrant_epoch: u64,
635    ) -> Result<bool, MemoryError> {
636        let row: Option<(i64,)> = zeph_db::query_as(sql!(
637            "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
638        ))
639        .bind(entity_name)
640        .fetch_optional(&self.pool)
641        .await?;
642
643        match row {
644            None => Ok(true), // entity deleted; Qdrant point is orphaned, not stale per epoch
645            Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
646        }
647    }
648}
649
#[cfg(test)]
mod tests {
    //! Tests run against an in-memory `SQLite` database and, where an `EmbeddingStore`
    //! is needed, an `InMemoryVectorStore` backend.

    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// Fresh in-memory database; returns the store (to keep it alive) and its pool.
    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an in-memory vector backend with a 4-dimensional collection.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        // Create collection first
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // No message row and no metadata row exist for this ID.
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_search_deduplicates_chunks() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "long message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                1,
            )
            .await
            .unwrap();

        // Both chunks belong to one message: a single result carrying the best chunk score.
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                    category: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn embedding_store_search_with_category_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let tagged = sqlite.save_message(cid, "user", "tagged").await.unwrap();
        let untagged = sqlite.save_message(cid, "user", "untagged").await.unwrap();

        store
            .store_with_category(
                tagged,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
                Some("preference"),
            )
            .await
            .unwrap();
        store
            .store_with_category(
                untagged,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
                None,
            )
            .await
            .unwrap();

        // Points stored without a category must not match a category filter.
        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: None,
                    role: None,
                    category: Some("preference".into()),
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, tagged);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // Same (message_id, chunk_index, model) — must fail.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;
        assert!(result.is_err());

        // Different chunk_index — must succeed.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(1_i64)
        .bind(&point_id3)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();
    }
}