Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! Qdrant-backed embedding store for message vector search.
//!
//! [`EmbeddingStore`] owns a [`VectorStore`] implementation (Qdrant via
//! [`EmbeddingStore::new`], [`crate::db_vector_store::DbVectorStore`] via
//! [`EmbeddingStore::new_sqlite`], or a caller-supplied backend via
//! [`EmbeddingStore::with_store`]) and exposes typed `store` / `search`
//! operations used by [`crate::semantic::SemanticMemory`].
//!
//! Message vectors are stored in the `zeph_conversations` collection with a
//! payload that includes `message_id`, `conversation_id`, `role`, and `is_summary`,
//! plus an optional `category` and optional tool-execution fields.
12
13pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Tells ordinary messages apart from summaries when an embedding is stored.
///
/// The store methods write this into each point's payload as the boolean
/// `is_summary` field (see [`MessageKind::is_summary`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A compression summary produced by the summarization subsystem.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for [`MessageKind::Regular`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
42
/// Name of the vector collection used by the conversation-message methods.
const COLLECTION_NAME: &str = "zeph_conversations";
44
45/// Ensure a Qdrant collection exists with cosine distance vectors.
46///
47/// Idempotent: no-op if the collection already exists.
48///
49/// # Errors
50///
51/// Returns an error if Qdrant cannot be reached or collection creation fails.
52pub async fn ensure_qdrant_collection(
53    ops: &QdrantOps,
54    collection: &str,
55    vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57    ops.ensure_collection(collection, vector_size).await
58}
59
60/// Typed wrapper over a [`VectorStore`] backend for conversation message embeddings.
61///
62/// Constructed via [`EmbeddingStore::new`] (Qdrant URL) or
63/// [`EmbeddingStore::with_store`] (custom backend for testing).
64pub struct EmbeddingStore {
65    ops: Box<dyn VectorStore>,
66    collection: String,
67    pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("EmbeddingStore")
73            .field("collection", &self.collection)
74            .finish_non_exhaustive()
75    }
76}
77
78/// Optional filters applied to a vector similarity search.
79#[derive(Debug)]
80pub struct SearchFilter {
81    /// Restrict results to a single conversation. `None` searches across all conversations.
82    pub conversation_id: Option<ConversationId>,
83    /// Restrict by message role (`"user"` / `"assistant"`). `None` returns all roles.
84    pub role: Option<String>,
85    /// Restrict by category payload field (category-aware memory, #2428).
86    /// When `Some`, Qdrant search is restricted to vectors with a matching `category` payload.
87    pub category: Option<String>,
88}
89
90/// A single result returned by [`EmbeddingStore::search`].
91#[derive(Debug)]
92pub struct SearchResult {
93    /// Database row ID of the matching message.
94    pub message_id: MessageId,
95    /// Conversation the message belongs to.
96    pub conversation_id: ConversationId,
97    /// Cosine similarity score in `[0, 1]`.
98    pub score: f32,
99}
100
101impl EmbeddingStore {
102    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
103    ///
104    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
105    /// table (which must already exist via sqlx migrations).
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the Qdrant client cannot be created.
110    pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
111        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
112
113        Ok(Self {
114            ops: Box::new(ops),
115            collection: COLLECTION_NAME.into(),
116            pool,
117        })
118    }
119
120    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
121    ///
122    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
123    #[must_use]
124    pub fn new_sqlite(pool: DbPool) -> Self {
125        let ops = DbVectorStore::new(pool.clone());
126        Self {
127            ops: Box::new(ops),
128            collection: COLLECTION_NAME.into(),
129            pool,
130        }
131    }
132
133    #[must_use]
134    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
135        Self {
136            ops: store,
137            collection: COLLECTION_NAME.into(),
138            pool,
139        }
140    }
141
142    pub async fn health_check(&self) -> bool {
143        self.ops.health_check().await.unwrap_or(false)
144    }
145
146    /// Ensure the collection exists in Qdrant with the given vector size.
147    ///
148    /// Idempotent: no-op if the collection already exists.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if Qdrant cannot be reached or collection creation fails.
153    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
154        self.ops
155            .ensure_collection(&self.collection, vector_size)
156            .await?;
157        // Create keyword indexes for the fields used in filtered recall so Qdrant can satisfy
158        // filter conditions in O(log n) instead of scanning all payload documents.
159        self.ops
160            .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
161            .await?;
162        Ok(())
163    }
164
165    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
166    ///
167    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
168    /// alongside the standard fields. This allows filtering and scoring by tool context
169    /// without corrupting the embedding vector with text prefixes.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
174    #[allow(clippy::too_many_arguments)]
175    pub async fn store_with_tool_context(
176        &self,
177        message_id: MessageId,
178        conversation_id: ConversationId,
179        role: &str,
180        vector: Vec<f32>,
181        kind: MessageKind,
182        model: &str,
183        chunk_index: u32,
184        tool_name: &str,
185        exit_code: Option<i32>,
186        timestamp: Option<&str>,
187    ) -> Result<String, MemoryError> {
188        let point_id = uuid::Uuid::new_v4().to_string();
189        let dimensions = i64::try_from(vector.len())?;
190
191        let mut payload = std::collections::HashMap::from([
192            ("message_id".to_owned(), serde_json::json!(message_id.0)),
193            (
194                "conversation_id".to_owned(),
195                serde_json::json!(conversation_id.0),
196            ),
197            ("role".to_owned(), serde_json::json!(role)),
198            (
199                "is_summary".to_owned(),
200                serde_json::json!(kind.is_summary()),
201            ),
202            ("tool_name".to_owned(), serde_json::json!(tool_name)),
203        ]);
204        if let Some(code) = exit_code {
205            payload.insert("exit_code".to_owned(), serde_json::json!(code));
206        }
207        if let Some(ts) = timestamp {
208            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
209        }
210
211        let point = VectorPoint {
212            id: point_id.clone(),
213            vector,
214            payload,
215        };
216
217        self.ops.upsert(&self.collection, vec![point]).await?;
218
219        let chunk_index_i64 = i64::from(chunk_index);
220        zeph_db::query(sql!(
221            "INSERT INTO embeddings_metadata \
222             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
223             VALUES (?, ?, ?, ?, ?) \
224             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
225             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
226        ))
227        .bind(message_id)
228        .bind(chunk_index_i64)
229        .bind(&point_id)
230        .bind(dimensions)
231        .bind(model)
232        .execute(&self.pool)
233        .await?;
234
235        Ok(point_id)
236    }
237
238    /// Store a vector in Qdrant and persist metadata to `SQLite`.
239    ///
240    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
241    /// when a long message is split into multiple embeddings.
242    ///
243    /// Returns the UUID of the newly created Qdrant point.
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
248    #[allow(clippy::too_many_arguments)]
249    pub async fn store(
250        &self,
251        message_id: MessageId,
252        conversation_id: ConversationId,
253        role: &str,
254        vector: Vec<f32>,
255        kind: MessageKind,
256        model: &str,
257        chunk_index: u32,
258    ) -> Result<String, MemoryError> {
259        let point_id = uuid::Uuid::new_v4().to_string();
260        let dimensions = i64::try_from(vector.len())?;
261
262        let payload = std::collections::HashMap::from([
263            ("message_id".to_owned(), serde_json::json!(message_id.0)),
264            (
265                "conversation_id".to_owned(),
266                serde_json::json!(conversation_id.0),
267            ),
268            ("role".to_owned(), serde_json::json!(role)),
269            (
270                "is_summary".to_owned(),
271                serde_json::json!(kind.is_summary()),
272            ),
273        ]);
274
275        let point = VectorPoint {
276            id: point_id.clone(),
277            vector,
278            payload,
279        };
280
281        self.ops.upsert(&self.collection, vec![point]).await?;
282
283        let chunk_index_i64 = i64::from(chunk_index);
284        zeph_db::query(sql!(
285            "INSERT INTO embeddings_metadata \
286             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
287             VALUES (?, ?, ?, ?, ?) \
288             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
289             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
290        ))
291        .bind(message_id)
292        .bind(chunk_index_i64)
293        .bind(&point_id)
294        .bind(dimensions)
295        .bind(model)
296        .execute(&self.pool)
297        .await?;
298
299        Ok(point_id)
300    }
301
302    /// Store a vector with an optional category tag in the Qdrant payload.
303    ///
304    /// Identical to [`Self::store`] but adds a `category` field to the payload when provided.
305    /// Used by category-aware memory (#2428) to enable category-filtered recall.
306    ///
307    /// Note: when `category` is `None` no `category` field is written to the Qdrant payload.
308    /// Memories stored before category-aware recall was enabled therefore won't match a
309    /// category filter — this is intentional (no silent false-positives), but a backfill
310    /// pass is needed if retrospective categorization is desired.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
315    #[allow(clippy::too_many_arguments)]
316    pub async fn store_with_category(
317        &self,
318        message_id: MessageId,
319        conversation_id: ConversationId,
320        role: &str,
321        vector: Vec<f32>,
322        kind: MessageKind,
323        model: &str,
324        chunk_index: u32,
325        category: Option<&str>,
326    ) -> Result<String, MemoryError> {
327        let point_id = uuid::Uuid::new_v4().to_string();
328        let dimensions = i64::try_from(vector.len())?;
329
330        let mut payload = std::collections::HashMap::from([
331            ("message_id".to_owned(), serde_json::json!(message_id.0)),
332            (
333                "conversation_id".to_owned(),
334                serde_json::json!(conversation_id.0),
335            ),
336            ("role".to_owned(), serde_json::json!(role)),
337            (
338                "is_summary".to_owned(),
339                serde_json::json!(kind.is_summary()),
340            ),
341        ]);
342        if let Some(cat) = category {
343            payload.insert("category".to_owned(), serde_json::json!(cat));
344        }
345
346        let point = VectorPoint {
347            id: point_id.clone(),
348            vector,
349            payload,
350        };
351
352        self.ops.upsert(&self.collection, vec![point]).await?;
353
354        let chunk_index_i64 = i64::from(chunk_index);
355        zeph_db::query(sql!(
356            "INSERT INTO embeddings_metadata \
357             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
358             VALUES (?, ?, ?, ?, ?) \
359             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
360             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
361        ))
362        .bind(message_id)
363        .bind(chunk_index_i64)
364        .bind(&point_id)
365        .bind(dimensions)
366        .bind(model)
367        .execute(&self.pool)
368        .await?;
369
370        Ok(point_id)
371    }
372
373    /// Search for similar vectors in Qdrant, returning up to `limit` results.
374    ///
375    /// # Errors
376    ///
377    /// Returns an error if the Qdrant search fails.
378    pub async fn search(
379        &self,
380        query_vector: &[f32],
381        limit: usize,
382        filter: Option<SearchFilter>,
383    ) -> Result<Vec<SearchResult>, MemoryError> {
384        let limit_u64 = u64::try_from(limit)?;
385
386        let vector_filter = filter.as_ref().and_then(|f| {
387            let mut must = Vec::new();
388            if let Some(cid) = f.conversation_id {
389                must.push(FieldCondition {
390                    field: "conversation_id".into(),
391                    value: FieldValue::Integer(cid.0),
392                });
393            }
394            if let Some(ref role) = f.role {
395                must.push(FieldCondition {
396                    field: "role".into(),
397                    value: FieldValue::Text(role.clone()),
398                });
399            }
400            if let Some(ref category) = f.category {
401                must.push(FieldCondition {
402                    field: "category".into(),
403                    value: FieldValue::Text(category.clone()),
404                });
405            }
406            if must.is_empty() {
407                None
408            } else {
409                Some(VectorFilter {
410                    must,
411                    must_not: vec![],
412                })
413            }
414        });
415
416        let results = self
417            .ops
418            .search(
419                &self.collection,
420                query_vector.to_vec(),
421                limit_u64,
422                vector_filter,
423            )
424            .await?;
425
426        // Deduplicate by message_id, keeping the chunk with the highest score.
427        // A single message may produce multiple Qdrant points (one per chunk).
428        let mut best: std::collections::HashMap<MessageId, SearchResult> =
429            std::collections::HashMap::new();
430        for point in results {
431            let Some(message_id) = point
432                .payload
433                .get("message_id")
434                .and_then(serde_json::Value::as_i64)
435            else {
436                continue;
437            };
438            let Some(conversation_id) = point
439                .payload
440                .get("conversation_id")
441                .and_then(serde_json::Value::as_i64)
442            else {
443                continue;
444            };
445            let message_id = MessageId(message_id);
446            let entry = best.entry(message_id).or_insert(SearchResult {
447                message_id,
448                conversation_id: ConversationId(conversation_id),
449                score: f32::NEG_INFINITY,
450            });
451            if point.score > entry.score {
452                entry.score = point.score;
453            }
454        }
455
456        let mut search_results: Vec<SearchResult> = best.into_values().collect();
457        search_results.sort_by(|a, b| {
458            b.score
459                .partial_cmp(&a.score)
460                .unwrap_or(std::cmp::Ordering::Equal)
461        });
462        search_results.truncate(limit);
463
464        Ok(search_results)
465    }
466
467    /// Check whether a named collection exists in the vector store.
468    ///
469    /// # Errors
470    ///
471    /// Returns an error if the store backend cannot be reached.
472    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
473        self.ops.collection_exists(name).await.map_err(Into::into)
474    }
475
476    /// Ensure a named collection exists in Qdrant with the given vector size.
477    ///
478    /// # Errors
479    ///
480    /// Returns an error if Qdrant cannot be reached or collection creation fails.
481    pub async fn ensure_named_collection(
482        &self,
483        name: &str,
484        vector_size: u64,
485    ) -> Result<(), MemoryError> {
486        self.ops.ensure_collection(name, vector_size).await?;
487        Ok(())
488    }
489
490    /// Store a vector in a named Qdrant collection with arbitrary payload.
491    ///
492    /// Returns the UUID of the newly created point.
493    ///
494    /// # Errors
495    ///
496    /// Returns an error if the Qdrant upsert fails.
497    pub async fn store_to_collection(
498        &self,
499        collection: &str,
500        payload: serde_json::Value,
501        vector: Vec<f32>,
502    ) -> Result<String, MemoryError> {
503        let point_id = uuid::Uuid::new_v4().to_string();
504        let payload_map: std::collections::HashMap<String, serde_json::Value> =
505            serde_json::from_value(payload)?;
506        let point = VectorPoint {
507            id: point_id.clone(),
508            vector,
509            payload: payload_map,
510        };
511        self.ops.upsert(collection, vec![point]).await?;
512        Ok(point_id)
513    }
514
515    /// Upsert a vector into a named collection, reusing an existing point ID.
516    ///
517    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if the Qdrant upsert fails.
522    pub async fn upsert_to_collection(
523        &self,
524        collection: &str,
525        point_id: &str,
526        payload: serde_json::Value,
527        vector: Vec<f32>,
528    ) -> Result<(), MemoryError> {
529        let payload_map: std::collections::HashMap<String, serde_json::Value> =
530            serde_json::from_value(payload)?;
531        let point = VectorPoint {
532            id: point_id.to_owned(),
533            vector,
534            payload: payload_map,
535        };
536        self.ops.upsert(collection, vec![point]).await?;
537        Ok(())
538    }
539
540    /// Search a named Qdrant collection, returning scored points with payloads.
541    ///
542    /// # Errors
543    ///
544    /// Returns an error if the Qdrant search fails.
545    pub async fn search_collection(
546        &self,
547        collection: &str,
548        query_vector: &[f32],
549        limit: usize,
550        filter: Option<VectorFilter>,
551    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
552        let limit_u64 = u64::try_from(limit)?;
553        let results = self
554            .ops
555            .search(collection, query_vector.to_vec(), limit_u64, filter)
556            .await?;
557        Ok(results)
558    }
559
560    /// Retrieve raw vectors for the given Qdrant point IDs from `collection`.
561    ///
562    /// Returns a map of `point_id → embedding`. Missing ids are silently dropped.
563    /// Returns an empty map when the backend does not support vector retrieval
564    /// (e.g. `DbVectorStore` / `InMemoryVectorStore` without an override).
565    ///
566    /// # Errors
567    ///
568    /// Returns an error if the underlying store returns a non-`Unsupported` error.
569    pub async fn get_vectors_from_collection(
570        &self,
571        collection: &str,
572        point_ids: &[String],
573    ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
574        if point_ids.is_empty() {
575            return Ok(std::collections::HashMap::new());
576        }
577        match self.ops.get_points(collection, point_ids.to_vec()).await {
578            Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
579            Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
580            Err(e) => Err(MemoryError::VectorStore(e)),
581        }
582    }
583
584    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
585    ///
586    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
587    ///
588    /// # Errors
589    ///
590    /// Returns an error if the `SQLite` query fails.
591    pub async fn get_vectors(
592        &self,
593        ids: &[MessageId],
594    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
595        if ids.is_empty() {
596            return Ok(std::collections::HashMap::new());
597        }
598
599        let placeholders = zeph_db::placeholder_list(1, ids.len());
600        let query = format!(
601            "SELECT em.message_id, vp.vector \
602             FROM embeddings_metadata em \
603             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
604             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
605        );
606        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
607        for &id in ids {
608            q = q.bind(id);
609        }
610
611        let rows = q.fetch_all(&self.pool).await?;
612
613        let map = rows
614            .into_iter()
615            .filter_map(|(msg_id, blob)| {
616                if blob.len() % 4 != 0 {
617                    return None;
618                }
619                let vec: Vec<f32> = blob
620                    .chunks_exact(4)
621                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
622                    .collect();
623                Some((msg_id, vec))
624            })
625            .collect();
626
627        Ok(map)
628    }
629
630    /// Check whether an embedding already exists for the given message ID.
631    ///
632    /// # Errors
633    ///
634    /// Returns an error if the `SQLite` query fails.
635    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
636        let row: (i64,) = zeph_db::query_as(sql!(
637            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
638        ))
639        .bind(message_id)
640        .fetch_one(&self.pool)
641        .await?;
642
643        Ok(row.0 > 0)
644    }
645
646    /// Check whether a Qdrant embedding for `entity_name` is current by comparing the
647    /// Qdrant-side epoch against the epoch stored in `graph_entities`.
648    ///
649    /// Returns `true` if the Qdrant embedding is up-to-date or if the entity no longer
650    /// exists in `SQLite` (embedding should be cleaned up separately).
651    ///
652    /// # Errors
653    ///
654    /// Returns an error if the `SQLite` query fails.
655    pub async fn is_epoch_current(
656        &self,
657        entity_name: &str,
658        qdrant_epoch: u64,
659    ) -> Result<bool, MemoryError> {
660        let row: Option<(i64,)> = zeph_db::query_as(sql!(
661            "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
662        ))
663        .bind(entity_name)
664        .fetch_optional(&self.pool)
665        .await?;
666
667        match row {
668            None => Ok(true), // entity deleted; Qdrant point is orphaned, not stale per epoch
669            Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
670        }
671    }
672}
673
674#[cfg(test)]
675mod tests {
676    use super::*;
677    use crate::in_memory_store::InMemoryVectorStore;
678    use crate::store::SqliteStore;
679
680    async fn setup() -> (SqliteStore, DbPool) {
681        let store = SqliteStore::new(":memory:").await.unwrap();
682        let pool = store.pool().clone();
683        (store, pool)
684    }
685
686    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
687        let sqlite = SqliteStore::new(":memory:").await.unwrap();
688        let pool = sqlite.pool().clone();
689        let mem_store = Box::new(InMemoryVectorStore::new());
690        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
691        // Create collection first
692        embedding_store.ensure_collection(4).await.unwrap();
693        (embedding_store, sqlite)
694    }
695
696    #[tokio::test]
697    async fn has_embedding_returns_false_when_none() {
698        let (_store, pool) = setup().await;
699
700        let row: (i64,) = zeph_db::query_as(sql!(
701            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
702        ))
703        .bind(999_i64)
704        .fetch_one(&pool)
705        .await
706        .unwrap();
707
708        assert_eq!(row.0, 0);
709    }
710
711    #[tokio::test]
712    async fn insert_and_query_embeddings_metadata() {
713        let (sqlite, pool) = setup().await;
714        let cid = sqlite.create_conversation().await.unwrap();
715        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
716
717        let point_id = uuid::Uuid::new_v4().to_string();
718        zeph_db::query(sql!(
719            "INSERT INTO embeddings_metadata \
720             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
721             VALUES (?, ?, ?, ?, ?)"
722        ))
723        .bind(msg_id)
724        .bind(0_i64)
725        .bind(&point_id)
726        .bind(768_i64)
727        .bind("qwen3-embedding")
728        .execute(&pool)
729        .await
730        .unwrap();
731
732        let row: (i64,) = zeph_db::query_as(sql!(
733            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
734        ))
735        .bind(msg_id)
736        .fetch_one(&pool)
737        .await
738        .unwrap();
739        assert_eq!(row.0, 1);
740    }
741
742    #[tokio::test]
743    async fn embedding_store_search_empty_returns_empty() {
744        let (store, _sqlite) = setup_with_store().await;
745        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
746        assert!(results.is_empty());
747    }
748
749    #[tokio::test]
750    async fn embedding_store_store_and_search() {
751        let (store, sqlite) = setup_with_store().await;
752        let cid = sqlite.create_conversation().await.unwrap();
753        let msg_id = sqlite
754            .save_message(cid, "user", "test message")
755            .await
756            .unwrap();
757
758        store
759            .store(
760                msg_id,
761                cid,
762                "user",
763                vec![1.0, 0.0, 0.0, 0.0],
764                MessageKind::Regular,
765                "test-model",
766                0,
767            )
768            .await
769            .unwrap();
770
771        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
772        assert_eq!(results.len(), 1);
773        assert_eq!(results[0].message_id, msg_id);
774        assert_eq!(results[0].conversation_id, cid);
775        assert!((results[0].score - 1.0).abs() < 0.001);
776    }
777
778    #[tokio::test]
779    async fn embedding_store_has_embedding_false_for_unknown() {
780        let (store, sqlite) = setup_with_store().await;
781        let cid = sqlite.create_conversation().await.unwrap();
782        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
783        assert!(!store.has_embedding(msg_id).await.unwrap());
784    }
785
786    #[tokio::test]
787    async fn embedding_store_has_embedding_true_after_store() {
788        let (store, sqlite) = setup_with_store().await;
789        let cid = sqlite.create_conversation().await.unwrap();
790        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
791
792        store
793            .store(
794                msg_id,
795                cid,
796                "user",
797                vec![0.0, 1.0, 0.0, 0.0],
798                MessageKind::Regular,
799                "test-model",
800                0,
801            )
802            .await
803            .unwrap();
804
805        assert!(store.has_embedding(msg_id).await.unwrap());
806    }
807
808    #[tokio::test]
809    async fn embedding_store_search_with_conversation_filter() {
810        let (store, sqlite) = setup_with_store().await;
811        let cid1 = sqlite.create_conversation().await.unwrap();
812        let cid2 = sqlite.create_conversation().await.unwrap();
813        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
814        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
815
816        store
817            .store(
818                msg1,
819                cid1,
820                "user",
821                vec![1.0, 0.0, 0.0, 0.0],
822                MessageKind::Regular,
823                "m",
824                0,
825            )
826            .await
827            .unwrap();
828        store
829            .store(
830                msg2,
831                cid2,
832                "user",
833                vec![1.0, 0.0, 0.0, 0.0],
834                MessageKind::Regular,
835                "m",
836                0,
837            )
838            .await
839            .unwrap();
840
841        let results = store
842            .search(
843                &[1.0, 0.0, 0.0, 0.0],
844                10,
845                Some(SearchFilter {
846                    conversation_id: Some(cid1),
847                    role: None,
848                    category: None,
849                }),
850            )
851            .await
852            .unwrap();
853        assert_eq!(results.len(), 1);
854        assert_eq!(results[0].conversation_id, cid1);
855    }
856
857    #[tokio::test]
858    async fn unique_constraint_on_message_chunk_and_model() {
859        let (sqlite, pool) = setup().await;
860        let cid = sqlite.create_conversation().await.unwrap();
861        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
862
863        let point_id1 = uuid::Uuid::new_v4().to_string();
864        zeph_db::query(sql!(
865            "INSERT INTO embeddings_metadata \
866             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
867             VALUES (?, ?, ?, ?, ?)"
868        ))
869        .bind(msg_id)
870        .bind(0_i64)
871        .bind(&point_id1)
872        .bind(768_i64)
873        .bind("qwen3-embedding")
874        .execute(&pool)
875        .await
876        .unwrap();
877
878        // Same (message_id, chunk_index, model) — must fail.
879        let point_id2 = uuid::Uuid::new_v4().to_string();
880        let result = zeph_db::query(sql!(
881            "INSERT INTO embeddings_metadata \
882             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
883             VALUES (?, ?, ?, ?, ?)"
884        ))
885        .bind(msg_id)
886        .bind(0_i64)
887        .bind(&point_id2)
888        .bind(768_i64)
889        .bind("qwen3-embedding")
890        .execute(&pool)
891        .await;
892        assert!(result.is_err());
893
894        // Different chunk_index — must succeed.
895        let point_id3 = uuid::Uuid::new_v4().to_string();
896        zeph_db::query(sql!(
897            "INSERT INTO embeddings_metadata \
898             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
899             VALUES (?, ?, ?, ?, ?)"
900        ))
901        .bind(msg_id)
902        .bind(1_i64)
903        .bind(&point_id3)
904        .bind(768_i64)
905        .bind("qwen3-embedding")
906        .execute(&pool)
907        .await
908        .unwrap();
909    }
910}