Skip to main content

zeph_memory/
embedding_registry.rs

1//! Generic embedding registry backed by Qdrant.
2//!
3//! Provides deduplication through content-hash delta tracking and collection-level
4//! embedding-model change detection.
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use qdrant_client::qdrant::{PointStruct, value::Kind};
11
12use crate::QdrantOps;
13use crate::vector_store::VectorStoreError;
14
/// Future produced by a caller-supplied embedding function.
///
/// It resolves either to the embedding vector or to a boxed, thread-safe error.
/// Boxing and pinning let closures return it without naming a concrete future
/// type, and the `Send` bound lets it be moved between threads before it is
/// awaited.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    >,
>;
19
20/// Trait implemented by domain types that can be stored in an [`EmbeddingRegistry`].
21pub trait Embeddable: Send + Sync {
22    /// Unique string key used for point-ID generation and delta tracking.
23    fn key(&self) -> &str;
24
25    /// blake3 hex hash of all semantically relevant fields.
26    fn content_hash(&self) -> String;
27
28    /// Text that will be embedded (e.g. description).
29    fn embed_text(&self) -> &str;
30
31    /// Full JSON payload to store in Qdrant. **Must** include a `"key"` field
32    /// equal to [`Self::key()`] so [`EmbeddingRegistry`] can recover it on scroll.
33    fn to_payload(&self) -> serde_json::Value;
34}
35
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Items whose embedding fails are skipped and do not appear in any counter.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items that had no stored point and were embedded and inserted.
    pub added: usize,
    /// Items that already had a stored point but were re-embedded and overwritten.
    pub updated: usize,
    /// Stored points whose key is no longer among the synced items.
    pub removed: usize,
    /// Items whose stored point was already up to date and was left untouched.
    pub unchanged: usize,
}
44
45/// Errors produced by [`EmbeddingRegistry`].
46#[derive(Debug, thiserror::Error)]
47pub enum EmbeddingRegistryError {
48    #[error("vector store error: {0}")]
49    VectorStore(#[from] VectorStoreError),
50
51    #[error("embedding error: {0}")]
52    Embedding(String),
53
54    #[error("serialization error: {0}")]
55    Serialization(String),
56
57    #[error("dimension probe failed: {0}")]
58    DimensionProbe(String),
59}
60
61impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
62    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
63        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
64    }
65}
66
67impl From<serde_json::Error> for EmbeddingRegistryError {
68    fn from(e: serde_json::Error) -> Self {
69        Self::Serialization(e.to_string())
70    }
71}
72
73impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
74    fn from(e: std::num::TryFromIntError) -> Self {
75        Self::DimensionProbe(e.to_string())
76    }
77}
78
79/// Generic Qdrant-backed embedding registry.
80///
81/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
82/// deterministic point IDs (uuid v5).  The in-memory `hashes` map enables
83/// O(1) delta detection between syncs.
84pub struct EmbeddingRegistry {
85    ops: QdrantOps,
86    collection: String,
87    namespace: uuid::Uuid,
88    hashes: HashMap<String, String>,
89}
90
91impl std::fmt::Debug for EmbeddingRegistry {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("EmbeddingRegistry")
94            .field("collection", &self.collection)
95            .finish_non_exhaustive()
96    }
97}
98
99impl EmbeddingRegistry {
100    /// Create a registry wrapping an existing [`QdrantOps`] connection.
101    #[must_use]
102    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
103        Self {
104            ops,
105            collection: collection.into(),
106            namespace,
107            hashes: HashMap::new(),
108        }
109    }
110
111    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
112    /// unnecessary re-embedding.  Re-creates the collection when the embedding
113    /// model changes.
114    ///
115    /// # Errors
116    ///
117    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
118    pub async fn sync<T: Embeddable>(
119        &mut self,
120        items: &[T],
121        embedding_model: &str,
122        embed_fn: impl Fn(&str) -> EmbedFuture,
123    ) -> Result<SyncStats, EmbeddingRegistryError> {
124        let mut stats = SyncStats::default();
125
126        self.ensure_collection(&embed_fn).await?;
127
128        let existing = self
129            .ops
130            .scroll_all(&self.collection, "key")
131            .await
132            .map_err(|e| {
133                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
134            })?;
135
136        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
137        for item in items {
138            current.insert(item.key().to_owned(), (item.content_hash(), item));
139        }
140
141        let model_changed = existing.values().any(|stored| {
142            stored
143                .get("embedding_model")
144                .is_some_and(|m| m != embedding_model)
145        });
146
147        if model_changed {
148            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
149            self.recreate_collection(&embed_fn).await?;
150        }
151
152        let mut points_to_upsert = Vec::new();
153        for (key, (hash, item)) in &current {
154            let needs_update = if let Some(stored) = existing.get(key) {
155                model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
156            } else {
157                true
158            };
159
160            if !needs_update {
161                stats.unchanged += 1;
162                self.hashes.insert(key.clone(), hash.clone());
163                continue;
164            }
165
166            let vector = match embed_fn(item.embed_text()).await {
167                Ok(v) => v,
168                Err(e) => {
169                    tracing::warn!("failed to embed item '{key}': {e:#}");
170                    continue;
171                }
172            };
173
174            let point_id = self.point_id(key);
175            let mut payload = item.to_payload();
176            if let Some(obj) = payload.as_object_mut() {
177                obj.insert(
178                    "content_hash".into(),
179                    serde_json::Value::String(hash.clone()),
180                );
181                obj.insert(
182                    "embedding_model".into(),
183                    serde_json::Value::String(embedding_model.to_owned()),
184                );
185            }
186            let payload_map = QdrantOps::json_to_payload(payload)?;
187
188            points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
189
190            if existing.contains_key(key) {
191                stats.updated += 1;
192            } else {
193                stats.added += 1;
194            }
195            self.hashes.insert(key.clone(), hash.clone());
196        }
197
198        if !points_to_upsert.is_empty() {
199            self.ops
200                .upsert(&self.collection, points_to_upsert)
201                .await
202                .map_err(|e| {
203                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
204                })?;
205        }
206
207        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
208            .keys()
209            .filter(|key| !current.contains_key(*key))
210            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
211            .collect();
212
213        if !orphan_ids.is_empty() {
214            stats.removed = orphan_ids.len();
215            self.ops
216                .delete_by_ids(&self.collection, orphan_ids)
217                .await
218                .map_err(|e| {
219                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
220                })?;
221        }
222
223        tracing::info!(
224            added = stats.added,
225            updated = stats.updated,
226            removed = stats.removed,
227            unchanged = stats.unchanged,
228            collection = &self.collection,
229            "embeddings synced"
230        );
231
232        Ok(stats)
233    }
234
235    /// Search the collection, returning raw scored Qdrant points.
236    ///
237    /// Consumers map the payloads to their domain types.
238    ///
239    /// # Errors
240    ///
241    /// Returns [`EmbeddingRegistryError`] if embedding or Qdrant search fails.
242    pub async fn search_raw(
243        &self,
244        query: &str,
245        limit: usize,
246        embed_fn: impl Fn(&str) -> EmbedFuture,
247    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
248        let query_vec = embed_fn(query)
249            .await
250            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
251
252        let Ok(limit_u64) = u64::try_from(limit) else {
253            return Ok(Vec::new());
254        };
255
256        let results = self
257            .ops
258            .search(&self.collection, query_vec, limit_u64, None)
259            .await
260            .map_err(|e| {
261                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
262            })?;
263
264        let scored: Vec<crate::ScoredVectorPoint> = results
265            .into_iter()
266            .map(|point| {
267                let payload: HashMap<String, serde_json::Value> = point
268                    .payload
269                    .into_iter()
270                    .filter_map(|(k, v)| {
271                        let json_val = match v.kind? {
272                            Kind::StringValue(s) => serde_json::Value::String(s),
273                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
274                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
275                            Kind::DoubleValue(d) => {
276                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
277                            }
278                            _ => return None,
279                        };
280                        Some((k, json_val))
281                    })
282                    .collect();
283
284                let id = match point.id.and_then(|pid| pid.point_id_options) {
285                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
286                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
287                    None => String::new(),
288                };
289
290                crate::ScoredVectorPoint {
291                    id,
292                    score: point.score,
293                    payload,
294                }
295            })
296            .collect();
297
298        Ok(scored)
299    }
300
301    fn point_id(&self, key: &str) -> String {
302        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
303    }
304
305    async fn ensure_collection(
306        &self,
307        embed_fn: &impl Fn(&str) -> EmbedFuture,
308    ) -> Result<(), EmbeddingRegistryError> {
309        if self
310            .ops
311            .collection_exists(&self.collection)
312            .await
313            .map_err(|e| {
314                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
315            })?
316        {
317            return Ok(());
318        }
319
320        let probe = embed_fn("dimension probe")
321            .await
322            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
323        let vector_size = u64::try_from(probe.len())?;
324
325        self.ops
326            .ensure_collection(&self.collection, vector_size)
327            .await
328            .map_err(|e| {
329                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
330            })?;
331
332        tracing::info!(
333            collection = &self.collection,
334            dimensions = vector_size,
335            "created Qdrant collection"
336        );
337
338        Ok(())
339    }
340
341    async fn recreate_collection(
342        &self,
343        embed_fn: &impl Fn(&str) -> EmbedFuture,
344    ) -> Result<(), EmbeddingRegistryError> {
345        if self
346            .ops
347            .collection_exists(&self.collection)
348            .await
349            .map_err(|e| {
350                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
351            })?
352        {
353            self.ops
354                .delete_collection(&self.collection)
355                .await
356                .map_err(|e| {
357                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
358                })?;
359            tracing::info!(
360                collection = &self.collection,
361                "deleted collection for recreation"
362            );
363        }
364        self.ensure_collection(embed_fn).await
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] implementation used to exercise the registry.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            // One-shot blake3 digest of the text, hex-encoded.
            blake3::hash(self.text.as_bytes()).to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({ "key": self.k, "text": self.text })
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.to_owned(),
            text: text.to_owned(),
        }
    }

    /// All-zero namespace shared by every test.
    fn nil_namespace() -> uuid::Uuid {
        uuid::Uuid::nil()
    }

    #[test]
    fn registry_new_valid_url() {
        let registry = EmbeddingRegistry::new(
            QdrantOps::new("http://localhost:6334").unwrap(),
            "test_col",
            nil_namespace(),
        );
        let rendered = format!("{registry:?}");
        for needle in ["EmbeddingRegistry", "test_col"] {
            assert!(rendered.contains(needle), "{needle:?} not found in {rendered}");
        }
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let (first, second) = (item.content_hash(), item.content_hash());
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let hashes: Vec<String> = ["text a", "text b"]
            .iter()
            .map(|text| make_item("key", text).content_hash())
            .collect();
        assert_ne!(hashes[0], hashes[1]);
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let SyncStats {
            added,
            updated,
            removed,
            unchanged,
        } = SyncStats::default();
        assert_eq!([added, updated, removed, unchanged], [0; 4]);
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        assert!(format!("{stats:?}").contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let registry = EmbeddingRegistry::new(
            QdrantOps::new("http://localhost:6334").unwrap(),
            "test",
            nil_namespace(),
        );
        let failing = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        assert!(registry.search_raw("query", 5, failing).await.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let mut registry = EmbeddingRegistry::new(
            QdrantOps::new("http://127.0.0.1:1").unwrap(),
            "test",
            nil_namespace(),
        );
        let items = [make_item("k", "text")];
        let constant = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        assert!(registry.sync(&items, "model", constant).await.is_err());
    }
}
476        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
477        let result = reg.sync(&items, "model", embed_fn).await;
478        assert!(result.is_err());
479    }
480}