Skip to main content

zeph_memory/
embedding_store.rs

1pub use qdrant_client::qdrant::Filter;
2use sqlx::SqlitePool;
3
4use crate::error::MemoryError;
5use crate::qdrant_ops::QdrantOps;
6use crate::types::{ConversationId, MessageId};
7use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
8
/// Tags an embedding as belonging to an ordinary message or to a summary.
///
/// The distinction is persisted alongside the vector (as the boolean `is_summary`
/// payload field written by `EmbeddingStore::store`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A summary rather than a regular message.
    Summary,
}
15
16impl MessageKind {
17    #[must_use]
18    pub fn is_summary(self) -> bool {
19        matches!(self, Self::Summary)
20    }
21}
22
/// Qdrant collection that `EmbeddingStore` uses for conversation message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
24
25/// Ensure a Qdrant collection exists with cosine distance vectors.
26///
27/// Idempotent: no-op if the collection already exists.
28///
29/// # Errors
30///
31/// Returns an error if Qdrant cannot be reached or collection creation fails.
32pub async fn ensure_qdrant_collection(
33    ops: &QdrantOps,
34    collection: &str,
35    vector_size: u64,
36) -> Result<(), Box<qdrant_client::QdrantError>> {
37    ops.ensure_collection(collection, vector_size).await
38}
39
40pub struct EmbeddingStore {
41    ops: Box<dyn VectorStore>,
42    collection: String,
43    pool: SqlitePool,
44}
45
46impl std::fmt::Debug for EmbeddingStore {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("EmbeddingStore")
49            .field("collection", &self.collection)
50            .finish_non_exhaustive()
51    }
52}
53
54#[derive(Debug)]
55pub struct SearchFilter {
56    pub conversation_id: Option<ConversationId>,
57    pub role: Option<String>,
58}
59
60#[derive(Debug)]
61pub struct SearchResult {
62    pub message_id: MessageId,
63    pub conversation_id: ConversationId,
64    pub score: f32,
65}
66
67impl EmbeddingStore {
68    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
69    ///
70    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
71    /// table (which must already exist via sqlx migrations).
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the Qdrant client cannot be created.
76    pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
77        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
78
79        Ok(Self {
80            ops: Box::new(ops),
81            collection: COLLECTION_NAME.into(),
82            pool,
83        })
84    }
85
86    #[must_use]
87    pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
88        Self {
89            ops: store,
90            collection: COLLECTION_NAME.into(),
91            pool,
92        }
93    }
94
95    /// Ensure the collection exists in Qdrant with the given vector size.
96    ///
97    /// Idempotent: no-op if the collection already exists.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if Qdrant cannot be reached or collection creation fails.
102    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
103        self.ops
104            .ensure_collection(&self.collection, vector_size)
105            .await?;
106        Ok(())
107    }
108
109    /// Store a vector in Qdrant and persist metadata to `SQLite`.
110    ///
111    /// Returns the UUID of the newly created Qdrant point.
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
116    pub async fn store(
117        &self,
118        message_id: MessageId,
119        conversation_id: ConversationId,
120        role: &str,
121        vector: Vec<f32>,
122        kind: MessageKind,
123        model: &str,
124    ) -> Result<String, MemoryError> {
125        let point_id = uuid::Uuid::new_v4().to_string();
126        let dimensions = i64::try_from(vector.len())?;
127
128        let payload = std::collections::HashMap::from([
129            ("message_id".to_owned(), serde_json::json!(message_id.0)),
130            (
131                "conversation_id".to_owned(),
132                serde_json::json!(conversation_id.0),
133            ),
134            ("role".to_owned(), serde_json::json!(role)),
135            (
136                "is_summary".to_owned(),
137                serde_json::json!(kind.is_summary()),
138            ),
139        ]);
140
141        let point = VectorPoint {
142            id: point_id.clone(),
143            vector,
144            payload,
145        };
146
147        self.ops.upsert(&self.collection, vec![point]).await?;
148
149        sqlx::query(
150            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
151             VALUES (?, ?, ?, ?) \
152             ON CONFLICT(message_id, model) DO UPDATE SET \
153             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
154        )
155        .bind(message_id)
156        .bind(&point_id)
157        .bind(dimensions)
158        .bind(model)
159        .execute(&self.pool)
160        .await?;
161
162        Ok(point_id)
163    }
164
165    /// Search for similar vectors in Qdrant, returning up to `limit` results.
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if the Qdrant search fails.
170    pub async fn search(
171        &self,
172        query_vector: &[f32],
173        limit: usize,
174        filter: Option<SearchFilter>,
175    ) -> Result<Vec<SearchResult>, MemoryError> {
176        let limit_u64 = u64::try_from(limit)?;
177
178        let vector_filter = filter.as_ref().and_then(|f| {
179            let mut must = Vec::new();
180            if let Some(cid) = f.conversation_id {
181                must.push(FieldCondition {
182                    field: "conversation_id".into(),
183                    value: FieldValue::Integer(cid.0),
184                });
185            }
186            if let Some(ref role) = f.role {
187                must.push(FieldCondition {
188                    field: "role".into(),
189                    value: FieldValue::Text(role.clone()),
190                });
191            }
192            if must.is_empty() {
193                None
194            } else {
195                Some(VectorFilter {
196                    must,
197                    must_not: vec![],
198                })
199            }
200        });
201
202        let results = self
203            .ops
204            .search(
205                &self.collection,
206                query_vector.to_vec(),
207                limit_u64,
208                vector_filter,
209            )
210            .await?;
211
212        let search_results = results
213            .into_iter()
214            .filter_map(|point| {
215                let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
216                let conversation_id =
217                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
218                Some(SearchResult {
219                    message_id,
220                    conversation_id,
221                    score: point.score,
222                })
223            })
224            .collect();
225
226        Ok(search_results)
227    }
228
229    /// Ensure a named collection exists in Qdrant with the given vector size.
230    ///
231    /// # Errors
232    ///
233    /// Returns an error if Qdrant cannot be reached or collection creation fails.
234    pub async fn ensure_named_collection(
235        &self,
236        name: &str,
237        vector_size: u64,
238    ) -> Result<(), MemoryError> {
239        self.ops.ensure_collection(name, vector_size).await?;
240        Ok(())
241    }
242
243    /// Store a vector in a named Qdrant collection with arbitrary payload.
244    ///
245    /// Returns the UUID of the newly created point.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if the Qdrant upsert fails.
250    pub async fn store_to_collection(
251        &self,
252        collection: &str,
253        payload: serde_json::Value,
254        vector: Vec<f32>,
255    ) -> Result<String, MemoryError> {
256        let point_id = uuid::Uuid::new_v4().to_string();
257        let payload_map: std::collections::HashMap<String, serde_json::Value> =
258            serde_json::from_value(payload)?;
259        let point = VectorPoint {
260            id: point_id.clone(),
261            vector,
262            payload: payload_map,
263        };
264        self.ops.upsert(collection, vec![point]).await?;
265        Ok(point_id)
266    }
267
268    /// Search a named Qdrant collection, returning scored points with payloads.
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if the Qdrant search fails.
273    pub async fn search_collection(
274        &self,
275        collection: &str,
276        query_vector: &[f32],
277        limit: usize,
278        filter: Option<VectorFilter>,
279    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
280        let limit_u64 = u64::try_from(limit)?;
281        let results = self
282            .ops
283            .search(collection, query_vector.to_vec(), limit_u64, filter)
284            .await?;
285        Ok(results)
286    }
287
288    /// Check whether an embedding already exists for the given message ID.
289    ///
290    /// # Errors
291    ///
292    /// Returns an error if the `SQLite` query fails.
293    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
294        let row: (i64,) =
295            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
296                .bind(message_id)
297                .fetch_one(&self.pool)
298                .await?;
299
300        Ok(row.0 > 0)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    /// In-memory `SQLite` store plus a clone of its pool, for raw metadata queries.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an in-memory vector backend with a 4-dimensional default
    /// collection.
    ///
    /// The `SqliteStore` is returned as well so tests can create the conversations and
    /// messages that embeddings refer to.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        // Create collection first
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // Exercise `has_embedding` itself rather than re-running its SQL by hand, for
        // a message id that does not exist at all.
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;

        assert!(result.is_err());
    }
}