// zeph_memory/embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Tags an embedding as belonging to an ordinary message or to a summary.
///
/// The tag is written into the point payload as the `is_summary` flag when an embedding
/// is stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary (non-summary) message.
    Regular,
    /// A summary.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for every other variant.
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
28
/// Collection name assigned to every `EmbeddingStore` built by the constructors in this file.
const COLLECTION_NAME: &str = "zeph_conversations";
30
31/// Ensure a Qdrant collection exists with cosine distance vectors.
32///
33/// Idempotent: no-op if the collection already exists.
34///
35/// # Errors
36///
37/// Returns an error if Qdrant cannot be reached or collection creation fails.
38pub async fn ensure_qdrant_collection(
39    ops: &QdrantOps,
40    collection: &str,
41    vector_size: u64,
42) -> Result<(), Box<qdrant_client::QdrantError>> {
43    ops.ensure_collection(collection, vector_size).await
44}
45
/// Pairs a vector backend with the database pool that holds embedding metadata.
pub struct EmbeddingStore {
    /// Vector backend (boxed trait object) that receives the upsert/search/collection calls.
    ops: Box<dyn VectorStore>,
    /// Collection used by `store`, `search` and `ensure_collection`.
    collection: String,
    /// Pool used for the `embeddings_metadata` queries issued by this type.
    pool: DbPool,
}
51
52impl std::fmt::Debug for EmbeddingStore {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("EmbeddingStore")
55            .field("collection", &self.collection)
56            .finish_non_exhaustive()
57    }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62    pub conversation_id: Option<ConversationId>,
63    pub role: Option<String>,
64}
65
/// One hit returned by [`EmbeddingStore::search`].
#[derive(Debug)]
pub struct SearchResult {
    /// Message ID read from the point's `message_id` payload entry.
    pub message_id: MessageId,
    /// Conversation ID read from the point's `conversation_id` payload entry.
    pub conversation_id: ConversationId,
    /// Score reported by the vector backend for this point.
    pub score: f32,
}
72
73impl EmbeddingStore {
74    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
75    ///
76    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
77    /// table (which must already exist via sqlx migrations).
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the Qdrant client cannot be created.
82    pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
83        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
84
85        Ok(Self {
86            ops: Box::new(ops),
87            collection: COLLECTION_NAME.into(),
88            pool,
89        })
90    }
91
92    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
93    ///
94    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
95    #[must_use]
96    pub fn new_sqlite(pool: DbPool) -> Self {
97        let ops = DbVectorStore::new(pool.clone());
98        Self {
99            ops: Box::new(ops),
100            collection: COLLECTION_NAME.into(),
101            pool,
102        }
103    }
104
105    #[must_use]
106    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
107        Self {
108            ops: store,
109            collection: COLLECTION_NAME.into(),
110            pool,
111        }
112    }
113
114    pub async fn health_check(&self) -> bool {
115        self.ops.health_check().await.unwrap_or(false)
116    }
117
118    /// Ensure the collection exists in Qdrant with the given vector size.
119    ///
120    /// Idempotent: no-op if the collection already exists.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if Qdrant cannot be reached or collection creation fails.
125    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
126        self.ops
127            .ensure_collection(&self.collection, vector_size)
128            .await?;
129        Ok(())
130    }
131
132    /// Store a vector in Qdrant and persist metadata to `SQLite`.
133    ///
134    /// Returns the UUID of the newly created Qdrant point.
135    ///
136    /// # Errors
137    ///
138    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
139    pub async fn store(
140        &self,
141        message_id: MessageId,
142        conversation_id: ConversationId,
143        role: &str,
144        vector: Vec<f32>,
145        kind: MessageKind,
146        model: &str,
147    ) -> Result<String, MemoryError> {
148        let point_id = uuid::Uuid::new_v4().to_string();
149        let dimensions = i64::try_from(vector.len())?;
150
151        let payload = std::collections::HashMap::from([
152            ("message_id".to_owned(), serde_json::json!(message_id.0)),
153            (
154                "conversation_id".to_owned(),
155                serde_json::json!(conversation_id.0),
156            ),
157            ("role".to_owned(), serde_json::json!(role)),
158            (
159                "is_summary".to_owned(),
160                serde_json::json!(kind.is_summary()),
161            ),
162        ]);
163
164        let point = VectorPoint {
165            id: point_id.clone(),
166            vector,
167            payload,
168        };
169
170        self.ops.upsert(&self.collection, vec![point]).await?;
171
172        zeph_db::query(sql!(
173            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
174             VALUES (?, ?, ?, ?) \
175             ON CONFLICT(message_id, model) DO UPDATE SET \
176             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
177        ))
178        .bind(message_id)
179        .bind(&point_id)
180        .bind(dimensions)
181        .bind(model)
182        .execute(&self.pool)
183        .await?;
184
185        Ok(point_id)
186    }
187
188    /// Search for similar vectors in Qdrant, returning up to `limit` results.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if the Qdrant search fails.
193    pub async fn search(
194        &self,
195        query_vector: &[f32],
196        limit: usize,
197        filter: Option<SearchFilter>,
198    ) -> Result<Vec<SearchResult>, MemoryError> {
199        let limit_u64 = u64::try_from(limit)?;
200
201        let vector_filter = filter.as_ref().and_then(|f| {
202            let mut must = Vec::new();
203            if let Some(cid) = f.conversation_id {
204                must.push(FieldCondition {
205                    field: "conversation_id".into(),
206                    value: FieldValue::Integer(cid.0),
207                });
208            }
209            if let Some(ref role) = f.role {
210                must.push(FieldCondition {
211                    field: "role".into(),
212                    value: FieldValue::Text(role.clone()),
213                });
214            }
215            if must.is_empty() {
216                None
217            } else {
218                Some(VectorFilter {
219                    must,
220                    must_not: vec![],
221                })
222            }
223        });
224
225        let results = self
226            .ops
227            .search(
228                &self.collection,
229                query_vector.to_vec(),
230                limit_u64,
231                vector_filter,
232            )
233            .await?;
234
235        let search_results = results
236            .into_iter()
237            .filter_map(|point| {
238                let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
239                let conversation_id =
240                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
241                Some(SearchResult {
242                    message_id,
243                    conversation_id,
244                    score: point.score,
245                })
246            })
247            .collect();
248
249        Ok(search_results)
250    }
251
252    /// Check whether a named collection exists in the vector store.
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if the store backend cannot be reached.
257    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
258        self.ops.collection_exists(name).await.map_err(Into::into)
259    }
260
261    /// Ensure a named collection exists in Qdrant with the given vector size.
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if Qdrant cannot be reached or collection creation fails.
266    pub async fn ensure_named_collection(
267        &self,
268        name: &str,
269        vector_size: u64,
270    ) -> Result<(), MemoryError> {
271        self.ops.ensure_collection(name, vector_size).await?;
272        Ok(())
273    }
274
275    /// Store a vector in a named Qdrant collection with arbitrary payload.
276    ///
277    /// Returns the UUID of the newly created point.
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if the Qdrant upsert fails.
282    pub async fn store_to_collection(
283        &self,
284        collection: &str,
285        payload: serde_json::Value,
286        vector: Vec<f32>,
287    ) -> Result<String, MemoryError> {
288        let point_id = uuid::Uuid::new_v4().to_string();
289        let payload_map: std::collections::HashMap<String, serde_json::Value> =
290            serde_json::from_value(payload)?;
291        let point = VectorPoint {
292            id: point_id.clone(),
293            vector,
294            payload: payload_map,
295        };
296        self.ops.upsert(collection, vec![point]).await?;
297        Ok(point_id)
298    }
299
300    /// Upsert a vector into a named collection, reusing an existing point ID.
301    ///
302    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
303    ///
304    /// # Errors
305    ///
306    /// Returns an error if the Qdrant upsert fails.
307    pub async fn upsert_to_collection(
308        &self,
309        collection: &str,
310        point_id: &str,
311        payload: serde_json::Value,
312        vector: Vec<f32>,
313    ) -> Result<(), MemoryError> {
314        let payload_map: std::collections::HashMap<String, serde_json::Value> =
315            serde_json::from_value(payload)?;
316        let point = VectorPoint {
317            id: point_id.to_owned(),
318            vector,
319            payload: payload_map,
320        };
321        self.ops.upsert(collection, vec![point]).await?;
322        Ok(())
323    }
324
325    /// Search a named Qdrant collection, returning scored points with payloads.
326    ///
327    /// # Errors
328    ///
329    /// Returns an error if the Qdrant search fails.
330    pub async fn search_collection(
331        &self,
332        collection: &str,
333        query_vector: &[f32],
334        limit: usize,
335        filter: Option<VectorFilter>,
336    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
337        let limit_u64 = u64::try_from(limit)?;
338        let results = self
339            .ops
340            .search(collection, query_vector.to_vec(), limit_u64, filter)
341            .await?;
342        Ok(results)
343    }
344
345    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
346    ///
347    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
348    ///
349    /// # Errors
350    ///
351    /// Returns an error if the `SQLite` query fails.
352    pub async fn get_vectors(
353        &self,
354        ids: &[MessageId],
355    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
356        if ids.is_empty() {
357            return Ok(std::collections::HashMap::new());
358        }
359
360        let placeholders = zeph_db::placeholder_list(1, ids.len());
361        let query = format!(
362            "SELECT em.message_id, vp.vector \
363             FROM embeddings_metadata em \
364             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
365             WHERE em.message_id IN ({placeholders})"
366        );
367        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
368        for &id in ids {
369            q = q.bind(id);
370        }
371
372        let rows = q.fetch_all(&self.pool).await?;
373
374        let map = rows
375            .into_iter()
376            .filter_map(|(msg_id, blob)| {
377                if blob.len() % 4 != 0 {
378                    return None;
379                }
380                let vec: Vec<f32> = blob
381                    .chunks_exact(4)
382                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
383                    .collect();
384                Some((msg_id, vec))
385            })
386            .collect();
387
388        Ok(map)
389    }
390
391    /// Check whether an embedding already exists for the given message ID.
392    ///
393    /// # Errors
394    ///
395    /// Returns an error if the `SQLite` query fails.
396    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
397        let row: (i64,) = zeph_db::query_as(sql!(
398            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
399        ))
400        .bind(message_id)
401        .fetch_one(&self.pool)
402        .await?;
403
404        Ok(row.0 > 0)
405    }
406}
407
408#[cfg(test)]
409mod tests {
410    use super::*;
411    use crate::in_memory_store::InMemoryVectorStore;
412    use crate::store::SqliteStore;
413
414    async fn setup() -> (SqliteStore, DbPool) {
415        let store = SqliteStore::new(":memory:").await.unwrap();
416        let pool = store.pool().clone();
417        (store, pool)
418    }
419
420    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
421        let sqlite = SqliteStore::new(":memory:").await.unwrap();
422        let pool = sqlite.pool().clone();
423        let mem_store = Box::new(InMemoryVectorStore::new());
424        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
425        // Create collection first
426        embedding_store.ensure_collection(4).await.unwrap();
427        (embedding_store, sqlite)
428    }
429
430    #[tokio::test]
431    async fn has_embedding_returns_false_when_none() {
432        let (_store, pool) = setup().await;
433
434        let row: (i64,) = zeph_db::query_as(sql!(
435            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
436        ))
437        .bind(999_i64)
438        .fetch_one(&pool)
439        .await
440        .unwrap();
441
442        assert_eq!(row.0, 0);
443    }
444
445    #[tokio::test]
446    async fn insert_and_query_embeddings_metadata() {
447        let (sqlite, pool) = setup().await;
448        let cid = sqlite.create_conversation().await.unwrap();
449        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
450
451        let point_id = uuid::Uuid::new_v4().to_string();
452        zeph_db::query(sql!(
453            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
454             VALUES (?, ?, ?, ?)"
455        ))
456        .bind(msg_id)
457        .bind(&point_id)
458        .bind(768_i64)
459        .bind("qwen3-embedding")
460        .execute(&pool)
461        .await
462        .unwrap();
463
464        let row: (i64,) = zeph_db::query_as(sql!(
465            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
466        ))
467        .bind(msg_id)
468        .fetch_one(&pool)
469        .await
470        .unwrap();
471        assert_eq!(row.0, 1);
472    }
473
474    #[tokio::test]
475    async fn embedding_store_search_empty_returns_empty() {
476        let (store, _sqlite) = setup_with_store().await;
477        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
478        assert!(results.is_empty());
479    }
480
481    #[tokio::test]
482    async fn embedding_store_store_and_search() {
483        let (store, sqlite) = setup_with_store().await;
484        let cid = sqlite.create_conversation().await.unwrap();
485        let msg_id = sqlite
486            .save_message(cid, "user", "test message")
487            .await
488            .unwrap();
489
490        store
491            .store(
492                msg_id,
493                cid,
494                "user",
495                vec![1.0, 0.0, 0.0, 0.0],
496                MessageKind::Regular,
497                "test-model",
498            )
499            .await
500            .unwrap();
501
502        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
503        assert_eq!(results.len(), 1);
504        assert_eq!(results[0].message_id, msg_id);
505        assert_eq!(results[0].conversation_id, cid);
506        assert!((results[0].score - 1.0).abs() < 0.001);
507    }
508
509    #[tokio::test]
510    async fn embedding_store_has_embedding_false_for_unknown() {
511        let (store, sqlite) = setup_with_store().await;
512        let cid = sqlite.create_conversation().await.unwrap();
513        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
514        assert!(!store.has_embedding(msg_id).await.unwrap());
515    }
516
517    #[tokio::test]
518    async fn embedding_store_has_embedding_true_after_store() {
519        let (store, sqlite) = setup_with_store().await;
520        let cid = sqlite.create_conversation().await.unwrap();
521        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
522
523        store
524            .store(
525                msg_id,
526                cid,
527                "user",
528                vec![0.0, 1.0, 0.0, 0.0],
529                MessageKind::Regular,
530                "test-model",
531            )
532            .await
533            .unwrap();
534
535        assert!(store.has_embedding(msg_id).await.unwrap());
536    }
537
538    #[tokio::test]
539    async fn embedding_store_search_with_conversation_filter() {
540        let (store, sqlite) = setup_with_store().await;
541        let cid1 = sqlite.create_conversation().await.unwrap();
542        let cid2 = sqlite.create_conversation().await.unwrap();
543        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
544        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
545
546        store
547            .store(
548                msg1,
549                cid1,
550                "user",
551                vec![1.0, 0.0, 0.0, 0.0],
552                MessageKind::Regular,
553                "m",
554            )
555            .await
556            .unwrap();
557        store
558            .store(
559                msg2,
560                cid2,
561                "user",
562                vec![1.0, 0.0, 0.0, 0.0],
563                MessageKind::Regular,
564                "m",
565            )
566            .await
567            .unwrap();
568
569        let results = store
570            .search(
571                &[1.0, 0.0, 0.0, 0.0],
572                10,
573                Some(SearchFilter {
574                    conversation_id: Some(cid1),
575                    role: None,
576                }),
577            )
578            .await
579            .unwrap();
580        assert_eq!(results.len(), 1);
581        assert_eq!(results[0].conversation_id, cid1);
582    }
583
584    #[tokio::test]
585    async fn unique_constraint_on_message_and_model() {
586        let (sqlite, pool) = setup().await;
587        let cid = sqlite.create_conversation().await.unwrap();
588        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
589
590        let point_id1 = uuid::Uuid::new_v4().to_string();
591        zeph_db::query(sql!(
592            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
593             VALUES (?, ?, ?, ?)"
594        ))
595        .bind(msg_id)
596        .bind(&point_id1)
597        .bind(768_i64)
598        .bind("qwen3-embedding")
599        .execute(&pool)
600        .await
601        .unwrap();
602
603        let point_id2 = uuid::Uuid::new_v4().to_string();
604        let result = zeph_db::query(sql!(
605            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
606             VALUES (?, ?, ?, ?)"
607        ))
608        .bind(msg_id)
609        .bind(&point_id2)
610        .bind(768_i64)
611        .bind("qwen3-embedding")
612        .execute(&pool)
613        .await;
614
615        assert!(result.is_err());
616    }
617}