Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4pub use qdrant_client::qdrant::Filter;
5use sqlx::SqlitePool;
6
7use crate::error::MemoryError;
8use crate::qdrant_ops::QdrantOps;
9use crate::sqlite_vector_store::SqliteVectorStore;
10use crate::types::{ConversationId, MessageId};
11use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
12
/// Kind of message an embedding was computed for.
///
/// Stored alongside the vector as the boolean payload field `is_summary`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A summary entry rather than an ordinary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
26
/// Default collection used by `EmbeddingStore` for per-message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
28
29/// Ensure a Qdrant collection exists with cosine distance vectors.
30///
31/// Idempotent: no-op if the collection already exists.
32///
33/// # Errors
34///
35/// Returns an error if Qdrant cannot be reached or collection creation fails.
36pub async fn ensure_qdrant_collection(
37    ops: &QdrantOps,
38    collection: &str,
39    vector_size: u64,
40) -> Result<(), Box<qdrant_client::QdrantError>> {
41    ops.ensure_collection(collection, vector_size).await
42}
43
44pub struct EmbeddingStore {
45    ops: Box<dyn VectorStore>,
46    collection: String,
47    pool: SqlitePool,
48}
49
50impl std::fmt::Debug for EmbeddingStore {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("EmbeddingStore")
53            .field("collection", &self.collection)
54            .finish_non_exhaustive()
55    }
56}
57
58#[derive(Debug)]
59pub struct SearchFilter {
60    pub conversation_id: Option<ConversationId>,
61    pub role: Option<String>,
62}
63
64#[derive(Debug)]
65pub struct SearchResult {
66    pub message_id: MessageId,
67    pub conversation_id: ConversationId,
68    pub score: f32,
69}
70
71impl EmbeddingStore {
72    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
73    ///
74    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
75    /// table (which must already exist via sqlx migrations).
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the Qdrant client cannot be created.
80    pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
81        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
82
83        Ok(Self {
84            ops: Box::new(ops),
85            collection: COLLECTION_NAME.into(),
86            pool,
87        })
88    }
89
90    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
91    ///
92    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
93    #[must_use]
94    pub fn new_sqlite(pool: SqlitePool) -> Self {
95        let ops = SqliteVectorStore::new(pool.clone());
96        Self {
97            ops: Box::new(ops),
98            collection: COLLECTION_NAME.into(),
99            pool,
100        }
101    }
102
103    #[must_use]
104    pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
105        Self {
106            ops: store,
107            collection: COLLECTION_NAME.into(),
108            pool,
109        }
110    }
111
112    pub async fn health_check(&self) -> bool {
113        self.ops.health_check().await.unwrap_or(false)
114    }
115
116    /// Ensure the collection exists in Qdrant with the given vector size.
117    ///
118    /// Idempotent: no-op if the collection already exists.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if Qdrant cannot be reached or collection creation fails.
123    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
124        self.ops
125            .ensure_collection(&self.collection, vector_size)
126            .await?;
127        Ok(())
128    }
129
130    /// Store a vector in Qdrant and persist metadata to `SQLite`.
131    ///
132    /// Returns the UUID of the newly created Qdrant point.
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
137    pub async fn store(
138        &self,
139        message_id: MessageId,
140        conversation_id: ConversationId,
141        role: &str,
142        vector: Vec<f32>,
143        kind: MessageKind,
144        model: &str,
145    ) -> Result<String, MemoryError> {
146        let point_id = uuid::Uuid::new_v4().to_string();
147        let dimensions = i64::try_from(vector.len())?;
148
149        let payload = std::collections::HashMap::from([
150            ("message_id".to_owned(), serde_json::json!(message_id.0)),
151            (
152                "conversation_id".to_owned(),
153                serde_json::json!(conversation_id.0),
154            ),
155            ("role".to_owned(), serde_json::json!(role)),
156            (
157                "is_summary".to_owned(),
158                serde_json::json!(kind.is_summary()),
159            ),
160        ]);
161
162        let point = VectorPoint {
163            id: point_id.clone(),
164            vector,
165            payload,
166        };
167
168        self.ops.upsert(&self.collection, vec![point]).await?;
169
170        sqlx::query(
171            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
172             VALUES (?, ?, ?, ?) \
173             ON CONFLICT(message_id, model) DO UPDATE SET \
174             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
175        )
176        .bind(message_id)
177        .bind(&point_id)
178        .bind(dimensions)
179        .bind(model)
180        .execute(&self.pool)
181        .await?;
182
183        Ok(point_id)
184    }
185
186    /// Search for similar vectors in Qdrant, returning up to `limit` results.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the Qdrant search fails.
191    pub async fn search(
192        &self,
193        query_vector: &[f32],
194        limit: usize,
195        filter: Option<SearchFilter>,
196    ) -> Result<Vec<SearchResult>, MemoryError> {
197        let limit_u64 = u64::try_from(limit)?;
198
199        let vector_filter = filter.as_ref().and_then(|f| {
200            let mut must = Vec::new();
201            if let Some(cid) = f.conversation_id {
202                must.push(FieldCondition {
203                    field: "conversation_id".into(),
204                    value: FieldValue::Integer(cid.0),
205                });
206            }
207            if let Some(ref role) = f.role {
208                must.push(FieldCondition {
209                    field: "role".into(),
210                    value: FieldValue::Text(role.clone()),
211                });
212            }
213            if must.is_empty() {
214                None
215            } else {
216                Some(VectorFilter {
217                    must,
218                    must_not: vec![],
219                })
220            }
221        });
222
223        let results = self
224            .ops
225            .search(
226                &self.collection,
227                query_vector.to_vec(),
228                limit_u64,
229                vector_filter,
230            )
231            .await?;
232
233        let search_results = results
234            .into_iter()
235            .filter_map(|point| {
236                let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
237                let conversation_id =
238                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
239                Some(SearchResult {
240                    message_id,
241                    conversation_id,
242                    score: point.score,
243                })
244            })
245            .collect();
246
247        Ok(search_results)
248    }
249
250    /// Check whether a named collection exists in the vector store.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if the store backend cannot be reached.
255    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
256        self.ops.collection_exists(name).await.map_err(Into::into)
257    }
258
259    /// Ensure a named collection exists in Qdrant with the given vector size.
260    ///
261    /// # Errors
262    ///
263    /// Returns an error if Qdrant cannot be reached or collection creation fails.
264    pub async fn ensure_named_collection(
265        &self,
266        name: &str,
267        vector_size: u64,
268    ) -> Result<(), MemoryError> {
269        self.ops.ensure_collection(name, vector_size).await?;
270        Ok(())
271    }
272
273    /// Store a vector in a named Qdrant collection with arbitrary payload.
274    ///
275    /// Returns the UUID of the newly created point.
276    ///
277    /// # Errors
278    ///
279    /// Returns an error if the Qdrant upsert fails.
280    pub async fn store_to_collection(
281        &self,
282        collection: &str,
283        payload: serde_json::Value,
284        vector: Vec<f32>,
285    ) -> Result<String, MemoryError> {
286        let point_id = uuid::Uuid::new_v4().to_string();
287        let payload_map: std::collections::HashMap<String, serde_json::Value> =
288            serde_json::from_value(payload)?;
289        let point = VectorPoint {
290            id: point_id.clone(),
291            vector,
292            payload: payload_map,
293        };
294        self.ops.upsert(collection, vec![point]).await?;
295        Ok(point_id)
296    }
297
298    /// Upsert a vector into a named collection, reusing an existing point ID.
299    ///
300    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if the Qdrant upsert fails.
305    pub async fn upsert_to_collection(
306        &self,
307        collection: &str,
308        point_id: &str,
309        payload: serde_json::Value,
310        vector: Vec<f32>,
311    ) -> Result<(), MemoryError> {
312        let payload_map: std::collections::HashMap<String, serde_json::Value> =
313            serde_json::from_value(payload)?;
314        let point = VectorPoint {
315            id: point_id.to_owned(),
316            vector,
317            payload: payload_map,
318        };
319        self.ops.upsert(collection, vec![point]).await?;
320        Ok(())
321    }
322
323    /// Search a named Qdrant collection, returning scored points with payloads.
324    ///
325    /// # Errors
326    ///
327    /// Returns an error if the Qdrant search fails.
328    pub async fn search_collection(
329        &self,
330        collection: &str,
331        query_vector: &[f32],
332        limit: usize,
333        filter: Option<VectorFilter>,
334    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
335        let limit_u64 = u64::try_from(limit)?;
336        let results = self
337            .ops
338            .search(collection, query_vector.to_vec(), limit_u64, filter)
339            .await?;
340        Ok(results)
341    }
342
343    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
344    ///
345    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
346    ///
347    /// # Errors
348    ///
349    /// Returns an error if the `SQLite` query fails.
350    pub async fn get_vectors(
351        &self,
352        ids: &[MessageId],
353    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
354        if ids.is_empty() {
355            return Ok(std::collections::HashMap::new());
356        }
357
358        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
359        let query = format!(
360            "SELECT em.message_id, vp.vector \
361             FROM embeddings_metadata em \
362             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
363             WHERE em.message_id IN ({placeholders})"
364        );
365        let mut q = sqlx::query_as::<_, (MessageId, Vec<u8>)>(&query);
366        for &id in ids {
367            q = q.bind(id);
368        }
369
370        let rows = q.fetch_all(&self.pool).await.unwrap_or_default();
371
372        let map = rows
373            .into_iter()
374            .filter_map(|(msg_id, blob)| {
375                if blob.len() % 4 != 0 {
376                    return None;
377                }
378                let vec: Vec<f32> = blob
379                    .chunks_exact(4)
380                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
381                    .collect();
382                Some((msg_id, vec))
383            })
384            .collect();
385
386        Ok(map)
387    }
388
389    /// Check whether an embedding already exists for the given message ID.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the `SQLite` query fails.
394    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
395        let row: (i64,) =
396            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
397                .bind(message_id)
398                .fetch_one(&self.pool)
399                .await?;
400
401        Ok(row.0 > 0)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    /// Fresh in-memory `SqliteStore` plus a clone of its pool.
    ///
    /// The tests below rely on `embeddings_metadata` existing right after this call.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an `InMemoryVectorStore`, with the default collection
    /// created for 4-dimensional vectors. The returned `SqliteStore` shares its pool.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        // Create collection first
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // Go through the public API rather than re-running its SQL by hand.
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn embedding_store_search_with_role_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let user_msg = sqlite.save_message(cid, "user", "question").await.unwrap();
        let assistant_msg = sqlite
            .save_message(cid, "assistant", "answer")
            .await
            .unwrap();

        for (msg, role) in [(user_msg, "user"), (assistant_msg, "assistant")] {
            store
                .store(
                    msg,
                    cid,
                    role,
                    vec![1.0, 0.0, 0.0, 0.0],
                    MessageKind::Regular,
                    "m",
                )
                .await
                .unwrap();
        }

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: None,
                    role: Some("assistant".into()),
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, assistant_msg);
    }

    #[tokio::test]
    async fn get_vectors_empty_ids_returns_empty_map() {
        let (store, _sqlite) = setup_with_store().await;
        let vectors = store.get_vectors(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn store_to_collection_is_searchable() {
        let (store, _sqlite) = setup_with_store().await;
        store.ensure_named_collection("extra", 4).await.unwrap();

        let point_id = store
            .store_to_collection(
                "extra",
                serde_json::json!({ "kind": "note" }),
                vec![0.0, 0.0, 1.0, 0.0],
            )
            .await
            .unwrap();
        assert!(!point_id.is_empty());

        let results = store
            .search_collection("extra", &[0.0, 0.0, 1.0, 0.0], 5, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].payload.get("kind").and_then(|v| v.as_str()),
            Some("note")
        );
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;

        assert!(result.is_err());
    }
}