Skip to main content

zeph_memory/
embedding_registry.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Generic embedding registry backed by Qdrant.
//!
//! Deduplicates writes by tracking a content hash per item, and detects
//! embedding-model changes at the collection level.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use qdrant_client::qdrant::{PointStruct, value::Kind};
14
15use crate::QdrantOps;
16use crate::vector_store::VectorStoreError;
17
/// Pinned, heap-allocated, `Send` future produced by an embedding callback.
///
/// It resolves to the embedding vector, or to a boxed `Send + Sync` error when
/// the embedding backend fails.
pub type EmbedFuture =
    Pin<Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
22
23/// Trait implemented by domain types that can be stored in an [`EmbeddingRegistry`].
24pub trait Embeddable: Send + Sync {
25    /// Unique string key used for point-ID generation and delta tracking.
26    fn key(&self) -> &str;
27
28    /// blake3 hex hash of all semantically relevant fields.
29    fn content_hash(&self) -> String;
30
31    /// Text that will be embedded (e.g. description).
32    fn embed_text(&self) -> &str;
33
34    /// Full JSON payload to store in Qdrant. **Must** include a `"key"` field
35    /// equal to [`Self::key()`] so [`EmbeddingRegistry`] can recover it on scroll.
36    fn to_payload(&self) -> serde_json::Value;
37}
38
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Items whose embedding failed are not counted in any field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items that had no stored point and were embedded and inserted.
    pub added: usize,
    /// Items that already had a stored point and were re-embedded.
    pub updated: usize,
    /// Stored points whose key no longer appears among the synced items.
    pub removed: usize,
    /// Items whose stored point was kept as-is.
    pub unchanged: usize,
}
47
48/// Errors produced by [`EmbeddingRegistry`].
49#[derive(Debug, thiserror::Error)]
50pub enum EmbeddingRegistryError {
51    #[error("vector store error: {0}")]
52    VectorStore(#[from] VectorStoreError),
53
54    #[error("embedding error: {0}")]
55    Embedding(String),
56
57    #[error("serialization error: {0}")]
58    Serialization(String),
59
60    #[error("dimension probe failed: {0}")]
61    DimensionProbe(String),
62}
63
64impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
65    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
66        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
67    }
68}
69
70impl From<serde_json::Error> for EmbeddingRegistryError {
71    fn from(e: serde_json::Error) -> Self {
72        Self::Serialization(e.to_string())
73    }
74}
75
76impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
77    fn from(e: std::num::TryFromIntError) -> Self {
78        Self::DimensionProbe(e.to_string())
79    }
80}
81
82/// Generic Qdrant-backed embedding registry.
83///
84/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
85/// deterministic point IDs (uuid v5).  The in-memory `hashes` map enables
86/// O(1) delta detection between syncs.
87pub struct EmbeddingRegistry {
88    ops: QdrantOps,
89    collection: String,
90    namespace: uuid::Uuid,
91    hashes: HashMap<String, String>,
92}
93
94impl std::fmt::Debug for EmbeddingRegistry {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("EmbeddingRegistry")
97            .field("collection", &self.collection)
98            .finish_non_exhaustive()
99    }
100}
101
102impl EmbeddingRegistry {
103    /// Create a registry wrapping an existing [`QdrantOps`] connection.
104    #[must_use]
105    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
106        Self {
107            ops,
108            collection: collection.into(),
109            namespace,
110            hashes: HashMap::new(),
111        }
112    }
113
114    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
115    /// unnecessary re-embedding.  Re-creates the collection when the embedding
116    /// model changes.
117    ///
118    /// # Errors
119    ///
120    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
121    pub async fn sync<T: Embeddable>(
122        &mut self,
123        items: &[T],
124        embedding_model: &str,
125        embed_fn: impl Fn(&str) -> EmbedFuture,
126    ) -> Result<SyncStats, EmbeddingRegistryError> {
127        let mut stats = SyncStats::default();
128
129        self.ensure_collection(&embed_fn).await?;
130
131        let existing = self
132            .ops
133            .scroll_all(&self.collection, "key")
134            .await
135            .map_err(|e| {
136                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
137            })?;
138
139        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
140        for item in items {
141            current.insert(item.key().to_owned(), (item.content_hash(), item));
142        }
143
144        let model_changed = existing.values().any(|stored| {
145            stored
146                .get("embedding_model")
147                .is_some_and(|m| m != embedding_model)
148        });
149
150        if model_changed {
151            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
152            self.recreate_collection(&embed_fn).await?;
153        }
154
155        let mut points_to_upsert = Vec::new();
156        for (key, (hash, item)) in &current {
157            let needs_update = if let Some(stored) = existing.get(key) {
158                model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
159            } else {
160                true
161            };
162
163            if !needs_update {
164                stats.unchanged += 1;
165                self.hashes.insert(key.clone(), hash.clone());
166                continue;
167            }
168
169            let vector = match embed_fn(item.embed_text()).await {
170                Ok(v) => v,
171                Err(e) => {
172                    tracing::warn!("failed to embed item '{key}': {e:#}");
173                    continue;
174                }
175            };
176
177            let point_id = self.point_id(key);
178            let mut payload = item.to_payload();
179            if let Some(obj) = payload.as_object_mut() {
180                obj.insert(
181                    "content_hash".into(),
182                    serde_json::Value::String(hash.clone()),
183                );
184                obj.insert(
185                    "embedding_model".into(),
186                    serde_json::Value::String(embedding_model.to_owned()),
187                );
188            }
189            let payload_map = QdrantOps::json_to_payload(payload)?;
190
191            points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
192
193            if existing.contains_key(key) {
194                stats.updated += 1;
195            } else {
196                stats.added += 1;
197            }
198            self.hashes.insert(key.clone(), hash.clone());
199        }
200
201        if !points_to_upsert.is_empty() {
202            self.ops
203                .upsert(&self.collection, points_to_upsert)
204                .await
205                .map_err(|e| {
206                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
207                })?;
208        }
209
210        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
211            .keys()
212            .filter(|key| !current.contains_key(*key))
213            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
214            .collect();
215
216        if !orphan_ids.is_empty() {
217            stats.removed = orphan_ids.len();
218            self.ops
219                .delete_by_ids(&self.collection, orphan_ids)
220                .await
221                .map_err(|e| {
222                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
223                })?;
224        }
225
226        tracing::info!(
227            added = stats.added,
228            updated = stats.updated,
229            removed = stats.removed,
230            unchanged = stats.unchanged,
231            collection = &self.collection,
232            "embeddings synced"
233        );
234
235        Ok(stats)
236    }
237
238    /// Search the collection, returning raw scored Qdrant points.
239    ///
240    /// Consumers map the payloads to their domain types.
241    ///
242    /// # Errors
243    ///
244    /// Returns [`EmbeddingRegistryError`] if embedding or Qdrant search fails.
245    pub async fn search_raw(
246        &self,
247        query: &str,
248        limit: usize,
249        embed_fn: impl Fn(&str) -> EmbedFuture,
250    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
251        let query_vec = embed_fn(query)
252            .await
253            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
254
255        let Ok(limit_u64) = u64::try_from(limit) else {
256            return Ok(Vec::new());
257        };
258
259        let results = self
260            .ops
261            .search(&self.collection, query_vec, limit_u64, None)
262            .await
263            .map_err(|e| {
264                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
265            })?;
266
267        let scored: Vec<crate::ScoredVectorPoint> = results
268            .into_iter()
269            .map(|point| {
270                let payload: HashMap<String, serde_json::Value> = point
271                    .payload
272                    .into_iter()
273                    .filter_map(|(k, v)| {
274                        let json_val = match v.kind? {
275                            Kind::StringValue(s) => serde_json::Value::String(s),
276                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
277                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
278                            Kind::DoubleValue(d) => {
279                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
280                            }
281                            _ => return None,
282                        };
283                        Some((k, json_val))
284                    })
285                    .collect();
286
287                let id = match point.id.and_then(|pid| pid.point_id_options) {
288                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
289                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
290                    None => String::new(),
291                };
292
293                crate::ScoredVectorPoint {
294                    id,
295                    score: point.score,
296                    payload,
297                }
298            })
299            .collect();
300
301        Ok(scored)
302    }
303
304    fn point_id(&self, key: &str) -> String {
305        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
306    }
307
308    async fn ensure_collection(
309        &self,
310        embed_fn: &impl Fn(&str) -> EmbedFuture,
311    ) -> Result<(), EmbeddingRegistryError> {
312        if self
313            .ops
314            .collection_exists(&self.collection)
315            .await
316            .map_err(|e| {
317                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
318            })?
319        {
320            return Ok(());
321        }
322
323        let probe = embed_fn("dimension probe")
324            .await
325            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
326        let vector_size = u64::try_from(probe.len())?;
327
328        self.ops
329            .ensure_collection(&self.collection, vector_size)
330            .await
331            .map_err(|e| {
332                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
333            })?;
334
335        tracing::info!(
336            collection = &self.collection,
337            dimensions = vector_size,
338            "created Qdrant collection"
339        );
340
341        Ok(())
342    }
343
344    async fn recreate_collection(
345        &self,
346        embed_fn: &impl Fn(&str) -> EmbedFuture,
347    ) -> Result<(), EmbeddingRegistryError> {
348        if self
349            .ops
350            .collection_exists(&self.collection)
351            .await
352            .map_err(|e| {
353                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
354            })?
355        {
356            self.ops
357                .delete_collection(&self.collection)
358                .await
359                .map_err(|e| {
360                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
361                })?;
362            tracing::info!(
363                collection = &self.collection,
364                "deleted collection for recreation"
365            );
366        }
367        self.ensure_collection(embed_fn).await
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] whose content hash covers only `text`.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            &self.k
        }

        fn content_hash(&self) -> String {
            let mut hasher = blake3::Hasher::new();
            hasher.update(self.text.as_bytes());
            hasher.finalize().to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            &self.text
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.into(),
            text: text.into(),
        }
    }

    /// All-zero namespace shared by every test registry.
    fn zero_namespace() -> uuid::Uuid {
        uuid::Uuid::from_bytes([0u8; 16])
    }

    /// Build a registry for `url` / `collection` using [`zero_namespace`].
    fn make_registry(url: &str, collection: &str) -> EmbeddingRegistry {
        let ops = QdrantOps::new(url).unwrap();
        EmbeddingRegistry::new(ops, collection, zero_namespace())
    }

    #[test]
    fn registry_new_valid_url() {
        let reg = make_registry("http://localhost:6334", "test_col");
        let dbg = format!("{reg:?}");
        assert!(dbg.contains("EmbeddingRegistry"));
        assert!(dbg.contains("test_col"));
    }

    #[test]
    fn point_id_is_deterministic_and_key_sensitive() {
        let reg = make_registry("http://localhost:6334", "test");
        assert_eq!(reg.point_id("a"), reg.point_id("a"));
        assert_ne!(reg.point_id("a"), reg.point_id("b"));
        let expected = uuid::Uuid::new_v5(&zero_namespace(), b"a").to_string();
        assert_eq!(reg.point_id("a"), expected);
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        assert_eq!(item.content_hash(), item.content_hash());
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let a = make_item("key", "text a");
        let b = make_item("key", "text b");
        assert_ne!(a.content_hash(), b.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let item = make_item("my-key", "desc");
        let payload = item.to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let s = SyncStats::default();
        assert_eq!(s.added, 0);
        assert_eq!(s.updated, 0);
        assert_eq!(s.removed, 0);
        assert_eq!(s.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let s = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let dbg = format!("{s:?}");
        assert!(dbg.contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let reg = make_registry("http://localhost:6334", "test");
        let embed_fn = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(matches!(result, Err(EmbeddingRegistryError::Embedding(_))));
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let mut reg = make_registry("http://127.0.0.1:1", "test");
        let items = vec![make_item("k", "text")];
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let result = reg.sync(&items, "model", embed_fn).await;
        assert!(matches!(result, Err(EmbeddingRegistryError::VectorStore(_))));
    }
}