Skip to main content

zeph_memory/
embedding_registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Generic embedding registry backed by Qdrant.
5//!
6//! Provides deduplication through content-hash delta tracking and collection-level
7//! embedding-model change detection.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use qdrant_client::qdrant::{PointStruct, value::Kind};
14
15use crate::QdrantOps;
16use crate::vector_store::VectorStoreError;
17
/// Pinned, heap-allocated, `Send` future produced by an embedding callback.
///
/// Resolves to the embedding vector on success, or to a boxed thread-safe
/// error on failure.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    >,
>;
22
23/// Trait implemented by domain types that can be stored in an [`EmbeddingRegistry`].
24pub trait Embeddable: Send + Sync {
25    /// Unique string key used for point-ID generation and delta tracking.
26    fn key(&self) -> &str;
27
28    /// blake3 hex hash of all semantically relevant fields.
29    fn content_hash(&self) -> String;
30
31    /// Text that will be embedded (e.g. description).
32    fn embed_text(&self) -> &str;
33
34    /// Full JSON payload to store in Qdrant. **Must** include a `"key"` field
35    /// equal to [`Self::key()`] so [`EmbeddingRegistry`] can recover it on scroll.
36    fn to_payload(&self) -> serde_json::Value;
37}
38
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Implements `PartialEq`/`Eq` so that whole results can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items whose key had no stored point and that were queued for upsert.
    pub added: usize,
    /// Items whose key already had a stored point and that were re-embedded.
    pub updated: usize,
    /// Stored points whose key was absent from the synced items.
    pub removed: usize,
    /// Items skipped because their stored point was considered up to date.
    pub unchanged: usize,
}
47
48/// Errors produced by [`EmbeddingRegistry`].
49#[derive(Debug, thiserror::Error)]
50pub enum EmbeddingRegistryError {
51    #[error("vector store error: {0}")]
52    VectorStore(#[from] VectorStoreError),
53
54    #[error("embedding error: {0}")]
55    Embedding(String),
56
57    #[error("serialization error: {0}")]
58    Serialization(String),
59
60    #[error("dimension probe failed: {0}")]
61    DimensionProbe(String),
62}
63
64impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
65    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
66        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
67    }
68}
69
70impl From<serde_json::Error> for EmbeddingRegistryError {
71    fn from(e: serde_json::Error) -> Self {
72        Self::Serialization(e.to_string())
73    }
74}
75
76impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
77    fn from(e: std::num::TryFromIntError) -> Self {
78        Self::DimensionProbe(e.to_string())
79    }
80}
81
82/// Generic Qdrant-backed embedding registry.
83///
84/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
85/// deterministic point IDs (uuid v5).  The in-memory `hashes` map enables
86/// O(1) delta detection between syncs.
87#[derive(Clone)]
88pub struct EmbeddingRegistry {
89    ops: QdrantOps,
90    collection: String,
91    namespace: uuid::Uuid,
92    hashes: HashMap<String, String>,
93}
94
95impl std::fmt::Debug for EmbeddingRegistry {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.debug_struct("EmbeddingRegistry")
98            .field("collection", &self.collection)
99            .finish_non_exhaustive()
100    }
101}
102
103impl EmbeddingRegistry {
104    /// Create a registry wrapping an existing [`QdrantOps`] connection.
105    #[must_use]
106    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
107        Self {
108            ops,
109            collection: collection.into(),
110            namespace,
111            hashes: HashMap::new(),
112        }
113    }
114
115    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
116    /// unnecessary re-embedding.  Re-creates the collection when the embedding
117    /// model changes.
118    ///
119    /// # Errors
120    ///
121    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
122    pub async fn sync<T: Embeddable>(
123        &mut self,
124        items: &[T],
125        embedding_model: &str,
126        embed_fn: impl Fn(&str) -> EmbedFuture,
127    ) -> Result<SyncStats, EmbeddingRegistryError> {
128        let mut stats = SyncStats::default();
129
130        self.ensure_collection(&embed_fn).await?;
131
132        let existing = self
133            .ops
134            .scroll_all(&self.collection, "key")
135            .await
136            .map_err(|e| {
137                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
138            })?;
139
140        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
141        for item in items {
142            current.insert(item.key().to_owned(), (item.content_hash(), item));
143        }
144
145        let model_changed = existing.values().any(|stored| {
146            stored
147                .get("embedding_model")
148                .is_some_and(|m| m != embedding_model)
149        });
150
151        if model_changed {
152            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
153            self.recreate_collection(&embed_fn).await?;
154        }
155
156        let mut points_to_upsert = Vec::new();
157        for (key, (hash, item)) in &current {
158            let needs_update = if let Some(stored) = existing.get(key) {
159                model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
160            } else {
161                true
162            };
163
164            if !needs_update {
165                stats.unchanged += 1;
166                self.hashes.insert(key.clone(), hash.clone());
167                continue;
168            }
169
170            let vector = match embed_fn(item.embed_text()).await {
171                Ok(v) => v,
172                Err(e) => {
173                    tracing::warn!("failed to embed item '{key}': {e:#}");
174                    continue;
175                }
176            };
177
178            let point_id = self.point_id(key);
179            let mut payload = item.to_payload();
180            if let Some(obj) = payload.as_object_mut() {
181                obj.insert(
182                    "content_hash".into(),
183                    serde_json::Value::String(hash.clone()),
184                );
185                obj.insert(
186                    "embedding_model".into(),
187                    serde_json::Value::String(embedding_model.to_owned()),
188                );
189            }
190            let payload_map = QdrantOps::json_to_payload(payload)?;
191
192            points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
193
194            if existing.contains_key(key) {
195                stats.updated += 1;
196            } else {
197                stats.added += 1;
198            }
199            self.hashes.insert(key.clone(), hash.clone());
200        }
201
202        if !points_to_upsert.is_empty() {
203            self.ops
204                .upsert(&self.collection, points_to_upsert)
205                .await
206                .map_err(|e| {
207                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
208                })?;
209        }
210
211        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
212            .keys()
213            .filter(|key| !current.contains_key(*key))
214            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
215            .collect();
216
217        if !orphan_ids.is_empty() {
218            stats.removed = orphan_ids.len();
219            self.ops
220                .delete_by_ids(&self.collection, orphan_ids)
221                .await
222                .map_err(|e| {
223                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
224                })?;
225        }
226
227        tracing::info!(
228            added = stats.added,
229            updated = stats.updated,
230            removed = stats.removed,
231            unchanged = stats.unchanged,
232            collection = &self.collection,
233            "embeddings synced"
234        );
235
236        Ok(stats)
237    }
238
239    /// Search the collection, returning raw scored Qdrant points.
240    ///
241    /// Consumers map the payloads to their domain types.
242    ///
243    /// # Errors
244    ///
245    /// Returns [`EmbeddingRegistryError`] if embedding or Qdrant search fails.
246    pub async fn search_raw(
247        &self,
248        query: &str,
249        limit: usize,
250        embed_fn: impl Fn(&str) -> EmbedFuture,
251    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
252        let query_vec = embed_fn(query)
253            .await
254            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
255
256        let Ok(limit_u64) = u64::try_from(limit) else {
257            return Ok(Vec::new());
258        };
259
260        let results = self
261            .ops
262            .search(&self.collection, query_vec, limit_u64, None)
263            .await
264            .map_err(|e| {
265                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
266            })?;
267
268        let scored: Vec<crate::ScoredVectorPoint> = results
269            .into_iter()
270            .map(|point| {
271                let payload: HashMap<String, serde_json::Value> = point
272                    .payload
273                    .into_iter()
274                    .filter_map(|(k, v)| {
275                        let json_val = match v.kind? {
276                            Kind::StringValue(s) => serde_json::Value::String(s),
277                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
278                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
279                            Kind::DoubleValue(d) => {
280                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
281                            }
282                            _ => return None,
283                        };
284                        Some((k, json_val))
285                    })
286                    .collect();
287
288                let id = match point.id.and_then(|pid| pid.point_id_options) {
289                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
290                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
291                    None => String::new(),
292                };
293
294                crate::ScoredVectorPoint {
295                    id,
296                    score: point.score,
297                    payload,
298                }
299            })
300            .collect();
301
302        Ok(scored)
303    }
304
305    fn point_id(&self, key: &str) -> String {
306        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
307    }
308
309    async fn ensure_collection(
310        &self,
311        embed_fn: &impl Fn(&str) -> EmbedFuture,
312    ) -> Result<(), EmbeddingRegistryError> {
313        if !self
314            .ops
315            .collection_exists(&self.collection)
316            .await
317            .map_err(|e| {
318                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
319            })?
320        {
321            // Collection does not exist — probe once and create.
322            let vector_size = self.probe_vector_size(embed_fn).await?;
323            self.ops
324                .ensure_collection(&self.collection, vector_size)
325                .await
326                .map_err(|e| {
327                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
328                })?;
329            tracing::info!(
330                collection = &self.collection,
331                dimensions = vector_size,
332                "created Qdrant collection"
333            );
334            return Ok(());
335        }
336
337        let existing_size = self
338            .ops
339            .client()
340            .collection_info(&self.collection)
341            .await
342            .map_err(|e| {
343                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
344            })?
345            .result
346            .and_then(|info| info.config)
347            .and_then(|cfg| cfg.params)
348            .and_then(|params| params.vectors_config)
349            .and_then(|vc| vc.config)
350            .and_then(|cfg| match cfg {
351                qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
352                // Named-vector collections (ParamsMap) are not supported by this registry;
353                // treat size as unknown and recreate to ensure a compatible single-vector layout.
354                qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
355            });
356
357        let vector_size = self.probe_vector_size(embed_fn).await?;
358
359        if existing_size == Some(vector_size) {
360            return Ok(());
361        }
362
363        tracing::warn!(
364            collection = &self.collection,
365            existing = ?existing_size,
366            required = vector_size,
367            "vector dimension mismatch, recreating collection"
368        );
369        self.ops
370            .delete_collection(&self.collection)
371            .await
372            .map_err(|e| {
373                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
374            })?;
375        self.ops
376            .ensure_collection(&self.collection, vector_size)
377            .await
378            .map_err(|e| {
379                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
380            })?;
381        tracing::info!(
382            collection = &self.collection,
383            dimensions = vector_size,
384            "created Qdrant collection"
385        );
386
387        Ok(())
388    }
389
390    async fn probe_vector_size(
391        &self,
392        embed_fn: &impl Fn(&str) -> EmbedFuture,
393    ) -> Result<u64, EmbeddingRegistryError> {
394        let probe = embed_fn("dimension probe")
395            .await
396            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
397        Ok(u64::try_from(probe.len())?)
398    }
399
400    async fn recreate_collection(
401        &self,
402        embed_fn: &impl Fn(&str) -> EmbedFuture,
403    ) -> Result<(), EmbeddingRegistryError> {
404        if self
405            .ops
406            .collection_exists(&self.collection)
407            .await
408            .map_err(|e| {
409                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
410            })?
411        {
412            self.ops
413                .delete_collection(&self.collection)
414                .await
415                .map_err(|e| {
416                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
417                })?;
418            tracing::info!(
419                collection = &self.collection,
420                "deleted collection for recreation"
421            );
422        }
423        self.ensure_collection(embed_fn).await
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] implementation used by the tests below.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            // Only `text` is semantically relevant for this test type.
            let digest = blake3::Hasher::new()
                .update(self.text.as_bytes())
                .finalize();
            digest.to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    /// Shorthand constructor for a [`TestItem`].
    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.to_owned(),
            text: text.to_owned(),
        }
    }

    /// Namespace shared by the registry tests (the all-zero UUID).
    fn zero_namespace() -> uuid::Uuid {
        uuid::Uuid::from_bytes([0u8; 16])
    }

    #[test]
    fn registry_new_valid_url() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let registry = EmbeddingRegistry::new(ops, "test_col", zero_namespace());
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("EmbeddingRegistry"));
        assert!(rendered.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let first = item.content_hash();
        let second = item.content_hash();
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let left = make_item("key", "text a");
        let right = make_item("key", "text b");
        assert_ne!(left.content_hash(), right.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let stats = SyncStats::default();
        assert_eq!(stats.added, 0);
        assert_eq!(stats.updated, 0);
        assert_eq!(stats.removed, 0);
        assert_eq!(stats.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let rendered = format!("{stats:?}");
        assert!(rendered.contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let registry = EmbeddingRegistry::new(ops, "test", zero_namespace());
        // The embedding step fails before Qdrant is ever contacted.
        let failing_embed = |_: &str| -> EmbedFuture {
            Box::pin(async {
                let cause: Box<dyn std::error::Error + Send + Sync> =
                    Box::new(std::io::Error::other("fail"));
                Err(cause)
            })
        };
        let outcome = registry.search_raw("query", 5, failing_embed).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let ops = QdrantOps::new("http://127.0.0.1:1").unwrap();
        let mut registry = EmbeddingRegistry::new(ops, "test", zero_namespace());
        let items = vec![make_item("k", "text")];
        let constant_embed =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let outcome = registry.sync(&items, "model", constant_embed).await;
        assert!(outcome.is_err());
    }
}