Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Distinguishes regular messages from summaries when storing embeddings.
///
/// The distinction is written into each stored point's payload as the boolean
/// `is_summary` field (see [`MessageKind::is_summary`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary message.
    Regular,
    /// A summary rather than an ordinary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` exactly when this is [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
28
29const COLLECTION_NAME: &str = "zeph_conversations";
30
31/// Ensure a Qdrant collection exists with cosine distance vectors.
32///
33/// Idempotent: no-op if the collection already exists.
34///
35/// # Errors
36///
37/// Returns an error if Qdrant cannot be reached or collection creation fails.
38pub async fn ensure_qdrant_collection(
39    ops: &QdrantOps,
40    collection: &str,
41    vector_size: u64,
42) -> Result<(), Box<qdrant_client::QdrantError>> {
43    ops.ensure_collection(collection, vector_size).await
44}
45
46pub struct EmbeddingStore {
47    ops: Box<dyn VectorStore>,
48    collection: String,
49    pool: DbPool,
50}
51
52impl std::fmt::Debug for EmbeddingStore {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("EmbeddingStore")
55            .field("collection", &self.collection)
56            .finish_non_exhaustive()
57    }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62    pub conversation_id: Option<ConversationId>,
63    pub role: Option<String>,
64}
65
66#[derive(Debug)]
67pub struct SearchResult {
68    pub message_id: MessageId,
69    pub conversation_id: ConversationId,
70    pub score: f32,
71}
72
73impl EmbeddingStore {
74    /// Create a new `EmbeddingStore` connected to the given Qdrant URL.
75    ///
76    /// The `pool` is used for `SQLite` metadata operations on the `embeddings_metadata`
77    /// table (which must already exist via sqlx migrations).
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the Qdrant client cannot be created.
82    pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
83        let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
84
85        Ok(Self {
86            ops: Box::new(ops),
87            collection: COLLECTION_NAME.into(),
88            pool,
89        })
90    }
91
92    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
93    ///
94    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
95    #[must_use]
96    pub fn new_sqlite(pool: DbPool) -> Self {
97        let ops = DbVectorStore::new(pool.clone());
98        Self {
99            ops: Box::new(ops),
100            collection: COLLECTION_NAME.into(),
101            pool,
102        }
103    }
104
105    #[must_use]
106    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
107        Self {
108            ops: store,
109            collection: COLLECTION_NAME.into(),
110            pool,
111        }
112    }
113
114    pub async fn health_check(&self) -> bool {
115        self.ops.health_check().await.unwrap_or(false)
116    }
117
118    /// Ensure the collection exists in Qdrant with the given vector size.
119    ///
120    /// Idempotent: no-op if the collection already exists.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if Qdrant cannot be reached or collection creation fails.
125    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
126        self.ops
127            .ensure_collection(&self.collection, vector_size)
128            .await?;
129        Ok(())
130    }
131
132    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
133    ///
134    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
135    /// alongside the standard fields. This allows filtering and scoring by tool context
136    /// without corrupting the embedding vector with text prefixes.
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
141    #[allow(clippy::too_many_arguments)]
142    pub async fn store_with_tool_context(
143        &self,
144        message_id: MessageId,
145        conversation_id: ConversationId,
146        role: &str,
147        vector: Vec<f32>,
148        kind: MessageKind,
149        model: &str,
150        chunk_index: u32,
151        tool_name: &str,
152        exit_code: Option<i32>,
153        timestamp: Option<&str>,
154    ) -> Result<String, MemoryError> {
155        let point_id = uuid::Uuid::new_v4().to_string();
156        let dimensions = i64::try_from(vector.len())?;
157
158        let mut payload = std::collections::HashMap::from([
159            ("message_id".to_owned(), serde_json::json!(message_id.0)),
160            (
161                "conversation_id".to_owned(),
162                serde_json::json!(conversation_id.0),
163            ),
164            ("role".to_owned(), serde_json::json!(role)),
165            (
166                "is_summary".to_owned(),
167                serde_json::json!(kind.is_summary()),
168            ),
169            ("tool_name".to_owned(), serde_json::json!(tool_name)),
170        ]);
171        if let Some(code) = exit_code {
172            payload.insert("exit_code".to_owned(), serde_json::json!(code));
173        }
174        if let Some(ts) = timestamp {
175            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
176        }
177
178        let point = VectorPoint {
179            id: point_id.clone(),
180            vector,
181            payload,
182        };
183
184        self.ops.upsert(&self.collection, vec![point]).await?;
185
186        let chunk_index_i64 = i64::from(chunk_index);
187        zeph_db::query(sql!(
188            "INSERT INTO embeddings_metadata \
189             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
190             VALUES (?, ?, ?, ?, ?) \
191             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
192             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
193        ))
194        .bind(message_id)
195        .bind(chunk_index_i64)
196        .bind(&point_id)
197        .bind(dimensions)
198        .bind(model)
199        .execute(&self.pool)
200        .await?;
201
202        Ok(point_id)
203    }
204
205    /// Store a vector in Qdrant and persist metadata to `SQLite`.
206    ///
207    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
208    /// when a long message is split into multiple embeddings.
209    ///
210    /// Returns the UUID of the newly created Qdrant point.
211    ///
212    /// # Errors
213    ///
214    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
215    #[allow(clippy::too_many_arguments)]
216    pub async fn store(
217        &self,
218        message_id: MessageId,
219        conversation_id: ConversationId,
220        role: &str,
221        vector: Vec<f32>,
222        kind: MessageKind,
223        model: &str,
224        chunk_index: u32,
225    ) -> Result<String, MemoryError> {
226        let point_id = uuid::Uuid::new_v4().to_string();
227        let dimensions = i64::try_from(vector.len())?;
228
229        let payload = std::collections::HashMap::from([
230            ("message_id".to_owned(), serde_json::json!(message_id.0)),
231            (
232                "conversation_id".to_owned(),
233                serde_json::json!(conversation_id.0),
234            ),
235            ("role".to_owned(), serde_json::json!(role)),
236            (
237                "is_summary".to_owned(),
238                serde_json::json!(kind.is_summary()),
239            ),
240        ]);
241
242        let point = VectorPoint {
243            id: point_id.clone(),
244            vector,
245            payload,
246        };
247
248        self.ops.upsert(&self.collection, vec![point]).await?;
249
250        let chunk_index_i64 = i64::from(chunk_index);
251        zeph_db::query(sql!(
252            "INSERT INTO embeddings_metadata \
253             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
254             VALUES (?, ?, ?, ?, ?) \
255             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
256             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
257        ))
258        .bind(message_id)
259        .bind(chunk_index_i64)
260        .bind(&point_id)
261        .bind(dimensions)
262        .bind(model)
263        .execute(&self.pool)
264        .await?;
265
266        Ok(point_id)
267    }
268
269    /// Search for similar vectors in Qdrant, returning up to `limit` results.
270    ///
271    /// # Errors
272    ///
273    /// Returns an error if the Qdrant search fails.
274    pub async fn search(
275        &self,
276        query_vector: &[f32],
277        limit: usize,
278        filter: Option<SearchFilter>,
279    ) -> Result<Vec<SearchResult>, MemoryError> {
280        let limit_u64 = u64::try_from(limit)?;
281
282        let vector_filter = filter.as_ref().and_then(|f| {
283            let mut must = Vec::new();
284            if let Some(cid) = f.conversation_id {
285                must.push(FieldCondition {
286                    field: "conversation_id".into(),
287                    value: FieldValue::Integer(cid.0),
288                });
289            }
290            if let Some(ref role) = f.role {
291                must.push(FieldCondition {
292                    field: "role".into(),
293                    value: FieldValue::Text(role.clone()),
294                });
295            }
296            if must.is_empty() {
297                None
298            } else {
299                Some(VectorFilter {
300                    must,
301                    must_not: vec![],
302                })
303            }
304        });
305
306        let results = self
307            .ops
308            .search(
309                &self.collection,
310                query_vector.to_vec(),
311                limit_u64,
312                vector_filter,
313            )
314            .await?;
315
316        // Deduplicate by message_id, keeping the chunk with the highest score.
317        // A single message may produce multiple Qdrant points (one per chunk).
318        let mut best: std::collections::HashMap<MessageId, SearchResult> =
319            std::collections::HashMap::new();
320        for point in results {
321            let Some(message_id) = point
322                .payload
323                .get("message_id")
324                .and_then(serde_json::Value::as_i64)
325            else {
326                continue;
327            };
328            let Some(conversation_id) = point
329                .payload
330                .get("conversation_id")
331                .and_then(serde_json::Value::as_i64)
332            else {
333                continue;
334            };
335            let message_id = MessageId(message_id);
336            let entry = best.entry(message_id).or_insert(SearchResult {
337                message_id,
338                conversation_id: ConversationId(conversation_id),
339                score: f32::NEG_INFINITY,
340            });
341            if point.score > entry.score {
342                entry.score = point.score;
343            }
344        }
345
346        let mut search_results: Vec<SearchResult> = best.into_values().collect();
347        search_results.sort_by(|a, b| {
348            b.score
349                .partial_cmp(&a.score)
350                .unwrap_or(std::cmp::Ordering::Equal)
351        });
352        search_results.truncate(limit);
353
354        Ok(search_results)
355    }
356
357    /// Check whether a named collection exists in the vector store.
358    ///
359    /// # Errors
360    ///
361    /// Returns an error if the store backend cannot be reached.
362    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
363        self.ops.collection_exists(name).await.map_err(Into::into)
364    }
365
366    /// Ensure a named collection exists in Qdrant with the given vector size.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if Qdrant cannot be reached or collection creation fails.
371    pub async fn ensure_named_collection(
372        &self,
373        name: &str,
374        vector_size: u64,
375    ) -> Result<(), MemoryError> {
376        self.ops.ensure_collection(name, vector_size).await?;
377        Ok(())
378    }
379
380    /// Store a vector in a named Qdrant collection with arbitrary payload.
381    ///
382    /// Returns the UUID of the newly created point.
383    ///
384    /// # Errors
385    ///
386    /// Returns an error if the Qdrant upsert fails.
387    pub async fn store_to_collection(
388        &self,
389        collection: &str,
390        payload: serde_json::Value,
391        vector: Vec<f32>,
392    ) -> Result<String, MemoryError> {
393        let point_id = uuid::Uuid::new_v4().to_string();
394        let payload_map: std::collections::HashMap<String, serde_json::Value> =
395            serde_json::from_value(payload)?;
396        let point = VectorPoint {
397            id: point_id.clone(),
398            vector,
399            payload: payload_map,
400        };
401        self.ops.upsert(collection, vec![point]).await?;
402        Ok(point_id)
403    }
404
405    /// Upsert a vector into a named collection, reusing an existing point ID.
406    ///
407    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
408    ///
409    /// # Errors
410    ///
411    /// Returns an error if the Qdrant upsert fails.
412    pub async fn upsert_to_collection(
413        &self,
414        collection: &str,
415        point_id: &str,
416        payload: serde_json::Value,
417        vector: Vec<f32>,
418    ) -> Result<(), MemoryError> {
419        let payload_map: std::collections::HashMap<String, serde_json::Value> =
420            serde_json::from_value(payload)?;
421        let point = VectorPoint {
422            id: point_id.to_owned(),
423            vector,
424            payload: payload_map,
425        };
426        self.ops.upsert(collection, vec![point]).await?;
427        Ok(())
428    }
429
430    /// Search a named Qdrant collection, returning scored points with payloads.
431    ///
432    /// # Errors
433    ///
434    /// Returns an error if the Qdrant search fails.
435    pub async fn search_collection(
436        &self,
437        collection: &str,
438        query_vector: &[f32],
439        limit: usize,
440        filter: Option<VectorFilter>,
441    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
442        let limit_u64 = u64::try_from(limit)?;
443        let results = self
444            .ops
445            .search(collection, query_vector.to_vec(), limit_u64, filter)
446            .await?;
447        Ok(results)
448    }
449
450    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
451    ///
452    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
453    ///
454    /// # Errors
455    ///
456    /// Returns an error if the `SQLite` query fails.
457    pub async fn get_vectors(
458        &self,
459        ids: &[MessageId],
460    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
461        if ids.is_empty() {
462            return Ok(std::collections::HashMap::new());
463        }
464
465        let placeholders = zeph_db::placeholder_list(1, ids.len());
466        let query = format!(
467            "SELECT em.message_id, vp.vector \
468             FROM embeddings_metadata em \
469             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
470             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
471        );
472        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
473        for &id in ids {
474            q = q.bind(id);
475        }
476
477        let rows = q.fetch_all(&self.pool).await?;
478
479        let map = rows
480            .into_iter()
481            .filter_map(|(msg_id, blob)| {
482                if blob.len() % 4 != 0 {
483                    return None;
484                }
485                let vec: Vec<f32> = blob
486                    .chunks_exact(4)
487                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
488                    .collect();
489                Some((msg_id, vec))
490            })
491            .collect();
492
493        Ok(map)
494    }
495
496    /// Check whether an embedding already exists for the given message ID.
497    ///
498    /// # Errors
499    ///
500    /// Returns an error if the `SQLite` query fails.
501    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
502        let row: (i64,) = zeph_db::query_as(sql!(
503            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
504        ))
505        .bind(message_id)
506        .fetch_one(&self.pool)
507        .await?;
508
509        Ok(row.0 > 0)
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// Fresh in-memory `SqliteStore` plus a clone of its pool.
    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an `InMemoryVectorStore` with a 4-dimensional collection.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        // Create collection first
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Number of `embeddings_metadata` rows recorded for `message_id`.
    async fn metadata_rows(pool: &DbPool, message_id: MessageId) -> i64 {
        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(message_id)
        .fetch_one(pool)
        .await
        .unwrap();
        row.0
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // Keep the SqliteStore alive for the duration of the test.
        let (_sqlite, pool) = setup().await;
        let store = EmbeddingStore::new_sqlite(pool);

        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        assert_eq!(metadata_rows(&pool, msg_id).await, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_search_zero_limit_returns_empty() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 0, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_search_dedups_chunks_keeping_best_score() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "a long chunked message")
            .await
            .unwrap();

        for (chunk_index, vector) in [
            (0, vec![1.0, 0.0, 0.0, 0.0]),
            (1, vec![0.0, 1.0, 0.0, 0.0]),
        ] {
            store
                .store(
                    msg_id,
                    cid,
                    "user",
                    vector,
                    MessageKind::Regular,
                    "test-model",
                    chunk_index,
                )
                .await
                .unwrap();
        }

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_store_with_tool_context_is_searchable() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "ran a tool")
            .await
            .unwrap();

        store
            .store_with_tool_context(
                msg_id,
                cid,
                "user",
                vec![0.0, 0.0, 1.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
                "shell",
                Some(0),
                Some("2026-01-01T00:00:00Z"),
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
        let results = store.search(&[0.0, 0.0, 1.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
    }

    #[tokio::test]
    async fn embedding_store_restoring_same_chunk_keeps_one_metadata_row() {
        let (store, sqlite) = setup_with_store().await;
        let pool = sqlite.pool().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "again").await.unwrap();

        let first = store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();
        let second = store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();
        assert_ne!(first, second);

        // ON CONFLICT repoints the existing row instead of inserting a second one.
        assert_eq!(metadata_rows(&pool, msg_id).await, 1);
        let row: (String,) = zeph_db::query_as(sql!(
            "SELECT qdrant_point_id FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, second);
    }

    #[tokio::test]
    async fn embedding_store_get_vectors_empty_ids_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(store.get_vectors(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // Same (message_id, chunk_index, model) — must fail.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;
        assert!(result.is_err());

        // Different chunk_index — must succeed.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(1_i64)
        .bind(&point_id3)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();
    }
}