Skip to main content

zeph_memory/
embedding_store.rs

1pub use qdrant_client::qdrant::Filter;
2use sqlx::SqlitePool;
3
4use crate::error::MemoryError;
5use crate::qdrant_ops::QdrantOps;
6use crate::sqlite_vector_store::SqliteVectorStore;
7use crate::types::{ConversationId, MessageId};
8use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
9
/// Tags an embedding as coming from an ordinary message or from a summary.
///
/// [`EmbeddingStore::store`] records this in the point payload as the boolean
/// `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary (non-summary) message.
    Regular,
    /// A summary.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for every other variant.
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
23
/// Name of the vector collection every `EmbeddingStore` constructor uses by default.
const COLLECTION_NAME: &str = "zeph_conversations";
25
26/// Ensure a Qdrant collection exists with cosine distance vectors.
27///
28/// Idempotent: no-op if the collection already exists.
29///
30/// # Errors
31///
32/// Returns an error if Qdrant cannot be reached or collection creation fails.
33pub async fn ensure_qdrant_collection(
34    ops: &QdrantOps,
35    collection: &str,
36    vector_size: u64,
37) -> Result<(), Box<qdrant_client::QdrantError>> {
38    ops.ensure_collection(collection, vector_size).await
39}
40
41pub struct EmbeddingStore {
42    ops: Box<dyn VectorStore>,
43    collection: String,
44    pool: SqlitePool,
45}
46
47impl std::fmt::Debug for EmbeddingStore {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("EmbeddingStore")
50            .field("collection", &self.collection)
51            .finish_non_exhaustive()
52    }
53}
54
55#[derive(Debug)]
56pub struct SearchFilter {
57    pub conversation_id: Option<ConversationId>,
58    pub role: Option<String>,
59}
60
61#[derive(Debug)]
62pub struct SearchResult {
63    pub message_id: MessageId,
64    pub conversation_id: ConversationId,
65    pub score: f32,
66}
67
68impl EmbeddingStore {
69    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
70    ///
71    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
72    /// table (which must already exist via sqlx migrations).
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the Qdrant client cannot be created.
77    pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
78        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
79
80        Ok(Self {
81            ops: Box::new(ops),
82            collection: COLLECTION_NAME.into(),
83            pool,
84        })
85    }
86
87    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
88    ///
89    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
90    #[must_use]
91    pub fn new_sqlite(pool: SqlitePool) -> Self {
92        let ops = SqliteVectorStore::new(pool.clone());
93        Self {
94            ops: Box::new(ops),
95            collection: COLLECTION_NAME.into(),
96            pool,
97        }
98    }
99
100    #[must_use]
101    pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
102        Self {
103            ops: store,
104            collection: COLLECTION_NAME.into(),
105            pool,
106        }
107    }
108
109    pub async fn health_check(&self) -> bool {
110        self.ops.health_check().await.unwrap_or(false)
111    }
112
113    /// Ensure the collection exists in Qdrant with the given vector size.
114    ///
115    /// Idempotent: no-op if the collection already exists.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if Qdrant cannot be reached or collection creation fails.
120    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
121        self.ops
122            .ensure_collection(&self.collection, vector_size)
123            .await?;
124        Ok(())
125    }
126
127    /// Store a vector in Qdrant and persist metadata to `SQLite`.
128    ///
129    /// Returns the UUID of the newly created Qdrant point.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
134    pub async fn store(
135        &self,
136        message_id: MessageId,
137        conversation_id: ConversationId,
138        role: &str,
139        vector: Vec<f32>,
140        kind: MessageKind,
141        model: &str,
142    ) -> Result<String, MemoryError> {
143        let point_id = uuid::Uuid::new_v4().to_string();
144        let dimensions = i64::try_from(vector.len())?;
145
146        let payload = std::collections::HashMap::from([
147            ("message_id".to_owned(), serde_json::json!(message_id.0)),
148            (
149                "conversation_id".to_owned(),
150                serde_json::json!(conversation_id.0),
151            ),
152            ("role".to_owned(), serde_json::json!(role)),
153            (
154                "is_summary".to_owned(),
155                serde_json::json!(kind.is_summary()),
156            ),
157        ]);
158
159        let point = VectorPoint {
160            id: point_id.clone(),
161            vector,
162            payload,
163        };
164
165        self.ops.upsert(&self.collection, vec![point]).await?;
166
167        sqlx::query(
168            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
169             VALUES (?, ?, ?, ?) \
170             ON CONFLICT(message_id, model) DO UPDATE SET \
171             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
172        )
173        .bind(message_id)
174        .bind(&point_id)
175        .bind(dimensions)
176        .bind(model)
177        .execute(&self.pool)
178        .await?;
179
180        Ok(point_id)
181    }
182
183    /// Search for similar vectors in Qdrant, returning up to `limit` results.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the Qdrant search fails.
188    pub async fn search(
189        &self,
190        query_vector: &[f32],
191        limit: usize,
192        filter: Option<SearchFilter>,
193    ) -> Result<Vec<SearchResult>, MemoryError> {
194        let limit_u64 = u64::try_from(limit)?;
195
196        let vector_filter = filter.as_ref().and_then(|f| {
197            let mut must = Vec::new();
198            if let Some(cid) = f.conversation_id {
199                must.push(FieldCondition {
200                    field: "conversation_id".into(),
201                    value: FieldValue::Integer(cid.0),
202                });
203            }
204            if let Some(ref role) = f.role {
205                must.push(FieldCondition {
206                    field: "role".into(),
207                    value: FieldValue::Text(role.clone()),
208                });
209            }
210            if must.is_empty() {
211                None
212            } else {
213                Some(VectorFilter {
214                    must,
215                    must_not: vec![],
216                })
217            }
218        });
219
220        let results = self
221            .ops
222            .search(
223                &self.collection,
224                query_vector.to_vec(),
225                limit_u64,
226                vector_filter,
227            )
228            .await?;
229
230        let search_results = results
231            .into_iter()
232            .filter_map(|point| {
233                let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
234                let conversation_id =
235                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
236                Some(SearchResult {
237                    message_id,
238                    conversation_id,
239                    score: point.score,
240                })
241            })
242            .collect();
243
244        Ok(search_results)
245    }
246
247    /// Ensure a named collection exists in Qdrant with the given vector size.
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if Qdrant cannot be reached or collection creation fails.
252    pub async fn ensure_named_collection(
253        &self,
254        name: &str,
255        vector_size: u64,
256    ) -> Result<(), MemoryError> {
257        self.ops.ensure_collection(name, vector_size).await?;
258        Ok(())
259    }
260
261    /// Store a vector in a named Qdrant collection with arbitrary payload.
262    ///
263    /// Returns the UUID of the newly created point.
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if the Qdrant upsert fails.
268    pub async fn store_to_collection(
269        &self,
270        collection: &str,
271        payload: serde_json::Value,
272        vector: Vec<f32>,
273    ) -> Result<String, MemoryError> {
274        let point_id = uuid::Uuid::new_v4().to_string();
275        let payload_map: std::collections::HashMap<String, serde_json::Value> =
276            serde_json::from_value(payload)?;
277        let point = VectorPoint {
278            id: point_id.clone(),
279            vector,
280            payload: payload_map,
281        };
282        self.ops.upsert(collection, vec![point]).await?;
283        Ok(point_id)
284    }
285
286    /// Search a named Qdrant collection, returning scored points with payloads.
287    ///
288    /// # Errors
289    ///
290    /// Returns an error if the Qdrant search fails.
291    pub async fn search_collection(
292        &self,
293        collection: &str,
294        query_vector: &[f32],
295        limit: usize,
296        filter: Option<VectorFilter>,
297    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
298        let limit_u64 = u64::try_from(limit)?;
299        let results = self
300            .ops
301            .search(collection, query_vector.to_vec(), limit_u64, filter)
302            .await?;
303        Ok(results)
304    }
305
306    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
307    ///
308    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
309    ///
310    /// # Errors
311    ///
312    /// Returns an error if the `SQLite` query fails.
313    pub async fn get_vectors(
314        &self,
315        ids: &[MessageId],
316    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
317        if ids.is_empty() {
318            return Ok(std::collections::HashMap::new());
319        }
320
321        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
322        let query = format!(
323            "SELECT em.message_id, vp.vector \
324             FROM embeddings_metadata em \
325             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
326             WHERE em.message_id IN ({placeholders})"
327        );
328        let mut q = sqlx::query_as::<_, (MessageId, Vec<u8>)>(&query);
329        for &id in ids {
330            q = q.bind(id);
331        }
332
333        let rows = q.fetch_all(&self.pool).await.unwrap_or_default();
334
335        let map = rows
336            .into_iter()
337            .filter_map(|(msg_id, blob)| {
338                if blob.len() % 4 != 0 {
339                    return None;
340                }
341                let vec: Vec<f32> = blob
342                    .chunks_exact(4)
343                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
344                    .collect();
345                Some((msg_id, vec))
346            })
347            .collect();
348
349        Ok(map)
350    }
351
352    /// Check whether an embedding already exists for the given message ID.
353    ///
354    /// # Errors
355    ///
356    /// Returns an error if the `SQLite` query fails.
357    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
358        let row: (i64,) =
359            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
360                .bind(message_id)
361                .fetch_one(&self.pool)
362                .await?;
363
364        Ok(row.0 > 0)
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    const INSERT_METADATA_SQL: &str =
        "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
         VALUES (?, ?, ?, ?)";
    const COUNT_METADATA_SQL: &str =
        "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?";

    /// Fresh in-memory `SQLite` store plus a clone of its pool.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        (sqlite, pool)
    }

    /// `EmbeddingStore` over an in-memory vector backend with a 4-dimensional collection.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let (sqlite, pool) = setup().await;
        let embedding_store =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), pool);
        // The collection must exist before anything is stored or searched.
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Create a conversation containing a single user message.
    async fn new_user_message(sqlite: &SqliteStore, text: &str) -> (ConversationId, MessageId) {
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", text).await.unwrap();
        (cid, msg_id)
    }

    /// Insert a raw metadata row for a fresh point id, bypassing `EmbeddingStore`.
    async fn insert_metadata(
        pool: &SqlitePool,
        message_id: MessageId,
    ) -> Result<sqlx::sqlite::SqliteQueryResult, sqlx::Error> {
        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(INSERT_METADATA_SQL)
            .bind(message_id)
            .bind(&point_id)
            .bind(768_i64)
            .bind("qwen3-embedding")
            .execute(pool)
            .await
    }

    /// Number of metadata rows recorded for `message_id`.
    async fn metadata_count(pool: &SqlitePool, message_id: MessageId) -> i64 {
        let (count,): (i64,) = sqlx::query_as(COUNT_METADATA_SQL)
            .bind(message_id)
            .fetch_one(pool)
            .await
            .unwrap();
        count
    }

    /// Store `vector` for a regular user message, panicking on failure.
    async fn embed(
        store: &EmbeddingStore,
        message_id: MessageId,
        conversation_id: ConversationId,
        vector: Vec<f32>,
        model: &str,
    ) {
        store
            .store(
                message_id,
                conversation_id,
                "user",
                vector,
                MessageKind::Regular,
                model,
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (_store, pool) = setup().await;

        assert_eq!(metadata_count(&pool, MessageId(999)).await, 0);
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let (_cid, msg_id) = new_user_message(&sqlite, "test").await;

        insert_metadata(&pool, msg_id).await.unwrap();

        assert_eq!(metadata_count(&pool, msg_id).await, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let (cid, msg_id) = new_user_message(&sqlite, "test message").await;

        embed(&store, msg_id, cid, vec![1.0, 0.0, 0.0, 0.0], "test-model").await;

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        // Query and stored vector are identical, so the score should be ~1.
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let (_cid, msg_id) = new_user_message(&sqlite, "test").await;

        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let (cid, msg_id) = new_user_message(&sqlite, "hello").await;

        embed(&store, msg_id, cid, vec![0.0, 1.0, 0.0, 0.0], "test-model").await;

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let (cid1, msg1) = new_user_message(&sqlite, "msg1").await;
        let (cid2, msg2) = new_user_message(&sqlite, "msg2").await;

        embed(&store, msg1, cid1, vec![1.0, 0.0, 0.0, 0.0], "m").await;
        embed(&store, msg2, cid2, vec![1.0, 0.0, 0.0, 0.0], "m").await;

        let only_first = SearchFilter {
            conversation_id: Some(cid1),
            role: None,
        };
        let results = store
            .search(&[1.0, 0.0, 0.0, 0.0], 10, Some(only_first))
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let (_cid, msg_id) = new_user_message(&sqlite, "test").await;

        insert_metadata(&pool, msg_id).await.unwrap();

        // Same message and model with a different point id must be rejected.
        let second = insert_metadata(&pool, msg_id).await;
        assert!(second.is_err());
    }
}
575}