Skip to main content

zeph_memory/
embedding_registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Generic embedding registry backed by Qdrant.
5//!
6//! Provides deduplication through content-hash delta tracking and collection-level
7//! embedding-model change detection.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use tokio::sync::RwLock;
15
16use futures::StreamExt as _;
17use qdrant_client::qdrant::{PointStruct, value::Kind};
18
19use crate::QdrantOps;
20use crate::vector_store::VectorStoreError;
21
/// Boxed, `Send` future returned by an embedding function.
///
/// Resolves to the embedding vector (`Vec<f32>`) on success, or to a boxed
/// `std::error::Error + Send + Sync` on failure.
pub type EmbedFuture = Pin<
    Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>,
>;
26
/// Domain type that can be stored in an [`EmbeddingRegistry`].
///
/// Implement this trait for any struct that should be embedded and persisted in Qdrant.
/// The registry uses [`key`](Self::key) and [`content_hash`](Self::content_hash) to
/// detect which items need to be re-embedded on each [`EmbeddingRegistry::sync`] call.
pub trait Embeddable: Send + Sync {
    /// Unique string key used for point-ID generation and delta tracking.
    ///
    /// The point ID is a UUID v5 derived from this key, so two items with the same key
    /// map to the same Qdrant point.
    fn key(&self) -> &str;

    /// BLAKE3 hex hash of all semantically relevant fields.
    ///
    /// The registry treats the value as an opaque string: it is stored in the point payload
    /// under `"content_hash"` and compared on the next sync.  When this hash changes between
    /// syncs the item's embedding is recomputed.
    fn content_hash(&self) -> String;

    /// Text that will be passed to the embedding model.
    fn embed_text(&self) -> &str;

    /// Full JSON payload to store in Qdrant alongside the vector.
    ///
    /// **Must** include a `"key"` field equal to [`Self::key()`] so
    /// [`EmbeddingRegistry`] can recover items on scroll.
    ///
    /// When the payload is a JSON object the registry adds `"content_hash"` and
    /// `"embedding_model"` entries to it before upserting, overwriting entries of the same name.
    fn to_payload(&self) -> serde_json::Value;
}
50
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare a whole result at once
/// instead of field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items embedded and upserted whose key was not present in the collection before the sync.
    pub added: usize,
    /// Items re-embedded and upserted whose key already existed in the collection.
    pub updated: usize,
    /// Stored points whose key is no longer among the synced items.
    pub removed: usize,
    /// Items whose stored content hash matched, so no embedding was computed.
    pub unchanged: usize,
}
59
/// Errors produced by [`EmbeddingRegistry`].
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum EmbeddingRegistryError {
    /// A Qdrant operation (collection management, scroll, upsert, delete or search) failed.
    #[error("vector store error: {0}")]
    VectorStore(#[from] VectorStoreError),

    /// The embedding function failed while embedding a search query.
    #[error("embedding error: {0}")]
    Embedding(String),

    /// A payload could not be converted to or from JSON.
    #[error("serialization error: {0}")]
    Serialization(String),

    /// The embedding dimension could not be probed, or a query vector's dimension does not
    /// match the dimension of the collection.
    #[error("dimension probe failed: {0}")]
    DimensionProbe(String),
}
76
77impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
78    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
79        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
80    }
81}
82
83impl From<serde_json::Error> for EmbeddingRegistryError {
84    fn from(e: serde_json::Error) -> Self {
85        Self::Serialization(e.to_string())
86    }
87}
88
/// Strips a trailing `:latest` tag from a model name.
///
/// Ollama appends `:latest` when no tag is specified, so the tagged and the bare name are
/// treated as equivalent.  Any other tag is left in place.
fn normalize_model_name(name: &str) -> &str {
    match name.strip_suffix(":latest") {
        Some(bare) => bare,
        None => name,
    }
}

/// Reports whether any stored point was embedded with a model other than `config_model`.
///
/// Model names are compared after normalizing `:latest` suffixes on both sides.  An empty
/// `config_model` never counts as a change.
///
/// A point without an `embedding_model` field (legacy points from pre-#3395 sessions) is
/// treated as a mismatch: its vector was produced by an unknown model and must be regenerated.
fn model_has_changed(
    existing: &HashMap<String, HashMap<String, String>>,
    config_model: &str,
) -> bool {
    if config_model.is_empty() {
        return false;
    }
    let wanted = normalize_model_name(config_model);
    existing.values().any(|stored| {
        // `None` (field absent) can never equal `Some(wanted)`, so legacy points count as changed.
        let recorded = stored
            .get("embedding_model")
            .map(String::as_str)
            .map(normalize_model_name);
        recorded != Some(wanted)
    })
}
114
/// Generic Qdrant-backed embedding registry.
///
/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
/// deterministic point IDs (uuid v5).  The in-memory `hashes` map records the content hash
/// last seen for each item key during [`sync`](Self::sync); the delta itself is computed
/// against the hashes stored in the Qdrant point payloads.
///
/// The `cached_dim` field caches the collection's vector dimension so that
/// [`search_raw`](Self::search_raw) can validate the query vector dimension without an extra
/// Qdrant round-trip on every call.  It is filled when [`sync`](Self::sync) ensures the
/// collection, or by the first `search_raw` call that has to probe Qdrant.  When a mismatch is
/// detected, `search_raw` returns [`EmbeddingRegistryError::DimensionProbe`] instead of silently
/// issuing a gRPC search that would return near-zero cosine scores (Qdrant gRPC behaviour on
/// dim mismatch).
///
/// Cloning copies `hashes` into the new instance but shares `cached_dim` with the original.
#[derive(Clone)]
pub struct EmbeddingRegistry {
    // Handle used for every Qdrant operation issued by this registry.
    ops: QdrantOps,
    // Name of the Qdrant collection this registry reads and writes.
    collection: String,
    // Namespace for the UUID v5 point IDs derived from item keys.
    namespace: uuid::Uuid,
    // Item key -> content hash recorded by `sync`.
    hashes: HashMap<String, String>,
    /// Maximum number of embedding requests dispatched concurrently during a sync.
    pub concurrency: usize,
    /// Vector dimension of the collection as last confirmed by this registry.  Shared via
    /// `Arc` so `Clone` works without invalidating the cached value across cloned instances.
    cached_dim: Arc<RwLock<Option<u64>>>,
}
138
139impl std::fmt::Debug for EmbeddingRegistry {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("EmbeddingRegistry")
142            .field("collection", &self.collection)
143            .finish_non_exhaustive()
144    }
145}
146
147impl EmbeddingRegistry {
148    /// Create a registry wrapping an existing [`QdrantOps`] connection.
149    #[must_use]
150    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
151        Self {
152            ops,
153            collection: collection.into(),
154            namespace,
155            hashes: HashMap::new(),
156            concurrency: 4,
157            cached_dim: Arc::new(RwLock::new(None)),
158        }
159    }
160
161    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
162    /// unnecessary re-embedding.  Re-creates the collection when the embedding
163    /// model changes.
164    ///
165    /// `on_progress`, when provided, is called after each successful embed+upsert with
166    /// `(completed, total)` counts so callers can display progress indicators.
167    ///
168    /// # Errors
169    ///
170    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
171    pub async fn sync<T: Embeddable>(
172        &mut self,
173        items: &[T],
174        embedding_model: &str,
175        embed_fn: impl Fn(&str) -> EmbedFuture,
176        on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
177    ) -> Result<SyncStats, EmbeddingRegistryError> {
178        let mut stats = SyncStats::default();
179
180        self.ensure_collection(&embed_fn).await?;
181
182        let existing = self
183            .ops
184            .scroll_all(&self.collection, "key")
185            .await
186            .map_err(|e| {
187                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
188            })?;
189
190        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
191        for item in items {
192            current.insert(item.key().to_owned(), (item.content_hash(), item));
193        }
194
195        let model_changed = model_has_changed(&existing, embedding_model);
196
197        if model_changed {
198            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
199            self.recreate_collection(&embed_fn).await?;
200        }
201
202        let work_items = build_work_set(
203            &current,
204            &existing,
205            model_changed,
206            &mut stats,
207            &mut self.hashes,
208        );
209
210        // Pre-create futures, point IDs, and payloads before taking the mutable borrow on
211        // self.hashes to avoid a double-borrow on `self`.
212        let work_with_futures: Vec<(String, String, EmbedFuture, String, serde_json::Value)> =
213            work_items
214                .into_iter()
215                .map(|(key, hash, item)| {
216                    let text = item.embed_text().to_owned();
217                    let fut = embed_fn(&text);
218                    let point_id = self.point_id(&key);
219                    let payload = item.to_payload();
220                    (key, hash, fut, point_id, payload)
221                })
222                .collect();
223
224        let points_to_upsert = embed_and_collect_points(
225            work_with_futures,
226            on_progress,
227            &existing,
228            embedding_model,
229            self.concurrency,
230            &mut stats,
231            &mut self.hashes,
232        )
233        .await?;
234
235        if !points_to_upsert.is_empty() {
236            self.ops
237                .upsert(&self.collection, points_to_upsert)
238                .await
239                .map_err(|e| {
240                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
241                })?;
242        }
243
244        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
245            .keys()
246            .filter(|key| !current.contains_key(*key))
247            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
248            .collect();
249
250        if !orphan_ids.is_empty() {
251            stats.removed = orphan_ids.len();
252            self.ops
253                .delete_by_ids(&self.collection, orphan_ids)
254                .await
255                .map_err(|e| {
256                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
257                })?;
258        }
259
260        tracing::info!(
261            added = stats.added,
262            updated = stats.updated,
263            removed = stats.removed,
264            unchanged = stats.unchanged,
265            collection = &self.collection,
266            "embeddings synced"
267        );
268
269        Ok(stats)
270    }
271
272    /// Search the collection, returning raw scored Qdrant points.
273    ///
274    /// Validates that the query vector dimension matches the collection before issuing the gRPC
275    /// call.  Qdrant gRPC silently returns near-zero cosine scores (~0.022) when dimensions
276    /// mismatch instead of returning an error — this guard prevents that silent failure.
277    ///
278    /// The dimension is checked against the cache populated by the most recent [`sync`](Self::sync)
279    /// call.  If no sync has occurred (cache is `None`) the check is skipped to avoid blocking
280    /// reads before the first sync.
281    ///
282    /// Consumers map the payloads to their domain types.
283    ///
284    /// # Errors
285    ///
286    /// Returns [`EmbeddingRegistryError::DimensionProbe`] when the query vector dimension does not
287    /// match the stored collection dimension.  Returns [`EmbeddingRegistryError::Embedding`] if the
288    /// embed function fails, or [`EmbeddingRegistryError::VectorStore`] on Qdrant search failure.
289    pub async fn search_raw(
290        &self,
291        query: &str,
292        limit: usize,
293        embed_fn: impl Fn(&str) -> EmbedFuture,
294    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
295        let query_vec = embed_fn(query)
296            .await
297            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
298
299        // Guard: Qdrant gRPC returns near-zero cosine scores when the query vector dimension
300        // does not match the stored collection dimension (issue #3418).  Check the cache first
301        // (populated by sync); fall back to a live Qdrant probe only when the cache is empty.
302        let collection_dim: Option<u64> = *self.cached_dim.read().await;
303
304        let collection_dim = if collection_dim.is_some() {
305            collection_dim
306        } else {
307            // Cache miss: ask Qdrant directly (first search before any sync), then populate cache.
308            let probed = self
309                .ops
310                .get_collection_vector_size(&self.collection)
311                .await
312                .map_err(|e| {
313                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
314                })?;
315            if let Some(d) = probed {
316                self.set_cached_dim(d).await;
317            }
318            probed
319        };
320
321        if let Some(stored_dim) = collection_dim {
322            // Safe: a Vec<f32> with 4B+ elements is impossible in practice on any 64-bit platform.
323            let query_dim = query_vec.len() as u64;
324            if query_dim != stored_dim {
325                return Err(EmbeddingRegistryError::DimensionProbe(format!(
326                    "query vector dimension {query_dim} does not match collection '{}' \
327                     dimension {stored_dim}; re-run sync to rebuild the collection",
328                    self.collection
329                )));
330            }
331        }
332
333        let Ok(limit_u64) = u64::try_from(limit) else {
334            return Ok(Vec::new());
335        };
336
337        let results = self
338            .ops
339            .search(&self.collection, query_vec, limit_u64, None)
340            .await
341            .map_err(|e| {
342                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
343            })?;
344
345        let scored: Vec<crate::ScoredVectorPoint> = results
346            .into_iter()
347            .map(|point| {
348                let payload: HashMap<String, serde_json::Value> = point
349                    .payload
350                    .into_iter()
351                    .filter_map(|(k, v)| {
352                        let json_val = match v.kind? {
353                            Kind::StringValue(s) => serde_json::Value::String(s),
354                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
355                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
356                            Kind::DoubleValue(d) => {
357                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
358                            }
359                            _ => return None,
360                        };
361                        Some((k, json_val))
362                    })
363                    .collect();
364
365                let id = match point.id.and_then(|pid| pid.point_id_options) {
366                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
367                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
368                    None => String::new(),
369                };
370
371                crate::ScoredVectorPoint {
372                    id,
373                    score: point.score,
374                    payload,
375                }
376            })
377            .collect();
378
379        Ok(scored)
380    }
381
382    fn point_id(&self, key: &str) -> String {
383        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
384    }
385
386    async fn ensure_collection(
387        &self,
388        embed_fn: &impl Fn(&str) -> EmbedFuture,
389    ) -> Result<(), EmbeddingRegistryError> {
390        if !self
391            .ops
392            .collection_exists(&self.collection)
393            .await
394            .map_err(|e| {
395                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
396            })?
397        {
398            // Collection does not exist — probe once and create.
399            let vector_size = self.probe_vector_size(embed_fn).await?;
400            self.ops
401                .ensure_collection(&self.collection, vector_size)
402                .await
403                .map_err(|e| {
404                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
405                })?;
406            tracing::info!(
407                collection = &self.collection,
408                dimensions = vector_size,
409                "created Qdrant collection"
410            );
411            self.set_cached_dim(vector_size).await;
412            return Ok(());
413        }
414
415        let existing_size = self
416            .ops
417            .client()
418            .collection_info(&self.collection)
419            .await
420            .map_err(|e| {
421                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
422            })?
423            .result
424            .and_then(|info| info.config)
425            .and_then(|cfg| cfg.params)
426            .and_then(|params| params.vectors_config)
427            .and_then(|vc| vc.config)
428            .and_then(|cfg| match cfg {
429                qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
430                // Named-vector collections (ParamsMap) are not supported by this registry;
431                // treat size as unknown and recreate to ensure a compatible single-vector layout.
432                qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
433            });
434
435        let vector_size = self.probe_vector_size(embed_fn).await?;
436
437        if existing_size == Some(vector_size) {
438            self.set_cached_dim(vector_size).await;
439            return Ok(());
440        }
441
442        tracing::warn!(
443            collection = &self.collection,
444            existing = ?existing_size,
445            required = vector_size,
446            "vector dimension mismatch, recreating collection"
447        );
448        self.ops
449            .delete_collection(&self.collection)
450            .await
451            .map_err(|e| {
452                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
453            })?;
454        self.ops
455            .ensure_collection(&self.collection, vector_size)
456            .await
457            .map_err(|e| {
458                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
459            })?;
460        tracing::info!(
461            collection = &self.collection,
462            dimensions = vector_size,
463            "created Qdrant collection"
464        );
465        self.set_cached_dim(vector_size).await;
466
467        Ok(())
468    }
469
470    /// Store `dim` in the dimension cache so `search_raw` can validate without a Qdrant round-trip.
471    async fn set_cached_dim(&self, dim: u64) {
472        *self.cached_dim.write().await = Some(dim);
473    }
474
475    async fn probe_vector_size(
476        &self,
477        embed_fn: &impl Fn(&str) -> EmbedFuture,
478    ) -> Result<u64, EmbeddingRegistryError> {
479        let probe = embed_fn("dimension probe")
480            .await
481            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
482        // Safe: a Vec<f32> with 4B+ elements is impossible in practice on any 64-bit platform.
483        Ok(probe.len() as u64)
484    }
485
486    async fn recreate_collection(
487        &self,
488        embed_fn: &impl Fn(&str) -> EmbedFuture,
489    ) -> Result<(), EmbeddingRegistryError> {
490        if self
491            .ops
492            .collection_exists(&self.collection)
493            .await
494            .map_err(|e| {
495                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
496            })?
497        {
498            self.ops
499                .delete_collection(&self.collection)
500                .await
501                .map_err(|e| {
502                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
503                })?;
504            tracing::info!(
505                collection = &self.collection,
506                "deleted collection for recreation"
507            );
508        }
509        self.ensure_collection(embed_fn).await
510    }
511}
512
513/// Determine which items need embedding and update stats for unchanged ones.
514///
515/// Returns a list of `(key, hash, item)` triples that require re-embedding.  Items whose
516/// stored hash matches the current hash are counted as `unchanged` in `stats` and their
517/// hashes are pre-populated in the `hashes` map.
518fn build_work_set<'a, T: Embeddable>(
519    current: &HashMap<String, (String, &'a T)>,
520    existing: &HashMap<String, HashMap<String, String>>,
521    model_changed: bool,
522    stats: &mut SyncStats,
523    hashes: &mut HashMap<String, String>,
524) -> Vec<(String, String, &'a T)> {
525    let mut work_items: Vec<(String, String, &'a T)> = Vec::new();
526    for (key, (hash, item)) in current {
527        let needs_update = if let Some(stored) = existing.get(key) {
528            model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
529        } else {
530            true
531        };
532
533        if needs_update {
534            work_items.push((key.clone(), hash.clone(), *item));
535        } else {
536            stats.unchanged += 1;
537            hashes.insert(key.clone(), hash.clone());
538        }
539    }
540    work_items
541}
542
543/// Await each pre-created embed future and collect the resulting Qdrant points.
544///
545/// Await each pre-created embed future and collect the resulting Qdrant points.
546///
547/// `work_items` is `(key, hash, embed_future, point_id, item_payload)` — point IDs and payloads
548/// must be pre-computed to avoid a double-borrow on the `EmbeddingRegistry` when `hashes` is
549/// mutably borrowed.
550///
551/// Processes futures with bounded concurrency (`concurrency` parameter).  Calls `on_progress`
552/// after each successful embed.  Updates `stats.added`/`stats.updated` and `hashes` in place.
553///
554/// Returns a `Vec<PointStruct>` ready for upsert, or an error if payload serialization fails.
555#[allow(clippy::too_many_arguments)]
556async fn embed_and_collect_points(
557    work_items: Vec<(String, String, EmbedFuture, String, serde_json::Value)>,
558    on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
559    existing: &HashMap<String, HashMap<String, String>>,
560    embedding_model: &str,
561    concurrency: usize,
562    stats: &mut SyncStats,
563    hashes: &mut HashMap<String, String>,
564) -> Result<Vec<PointStruct>, EmbeddingRegistryError> {
565    let total = work_items.len();
566    // Clamp concurrency to at least 1: buffer_unordered(0) silently skips all futures.
567    let concurrency = concurrency.max(1);
568
569    // Stream results as they complete so on_progress fires in real time, not after collect.
570    let mut stream =
571        futures::stream::iter(work_items.into_iter().map(
572            |(key, hash, fut, point_id, payload)| async move {
573                (key, hash, fut.await, point_id, payload)
574            },
575        ))
576        .buffer_unordered(concurrency);
577
578    let mut points_to_upsert = Vec::new();
579    let mut completed: usize = 0;
580    while let Some((key, hash, result, point_id, mut payload)) = stream.next().await {
581        let vector = match result {
582            Ok(v) => v,
583            Err(e) => {
584                tracing::warn!("failed to embed item '{key}': {e:#}");
585                continue;
586            }
587        };
588
589        if let Some(obj) = payload.as_object_mut() {
590            obj.insert(
591                "content_hash".into(),
592                serde_json::Value::String(hash.clone()),
593            );
594            obj.insert(
595                "embedding_model".into(),
596                serde_json::Value::String(embedding_model.to_owned()),
597            );
598        }
599        let payload_map = QdrantOps::json_to_payload(payload)?;
600
601        points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
602
603        if existing.contains_key(&key) {
604            stats.updated += 1;
605        } else {
606            stats.added += 1;
607        }
608        hashes.insert(key, hash);
609
610        completed += 1;
611        if let Some(ref cb) = on_progress {
612            cb(completed, total);
613        }
614    }
615    Ok(points_to_upsert)
616}
617
618#[cfg(test)]
619mod tests {
620    use super::*;
621
622    #[test]
623    fn normalize_no_suffix() {
624        assert_eq!(normalize_model_name("foo"), "foo");
625    }
626
627    #[test]
628    fn normalize_strips_latest() {
629        assert_eq!(normalize_model_name("foo:latest"), "foo");
630    }
631
632    #[test]
633    fn normalize_other_tag_unchanged() {
634        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
635    }
636
    /// Minimal [`Embeddable`] fixture: a key plus the text to embed.
    struct TestItem {
        // Returned by `key()` and stored as the payload's `"key"` field.
        k: String,
        // Returned by `embed_text()` and hashed by `content_hash()`.
        text: String,
    }
641
642    impl Embeddable for TestItem {
643        fn key(&self) -> &str {
644            &self.k
645        }
646
647        fn content_hash(&self) -> String {
648            let mut hasher = blake3::Hasher::new();
649            hasher.update(self.text.as_bytes());
650            hasher.finalize().to_hex().to_string()
651        }
652
653        fn embed_text(&self) -> &str {
654            &self.text
655        }
656
657        fn to_payload(&self) -> serde_json::Value {
658            serde_json::json!({"key": self.k, "text": self.text})
659        }
660    }
661
662    fn make_item(k: &str, text: &str) -> TestItem {
663        TestItem {
664            k: k.into(),
665            text: text.into(),
666        }
667    }
668
669    #[test]
670    fn registry_new_valid_url() {
671        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
672        let ns = uuid::Uuid::from_bytes([0u8; 16]);
673        let reg = EmbeddingRegistry::new(ops, "test_col", ns);
674        let dbg = format!("{reg:?}");
675        assert!(dbg.contains("EmbeddingRegistry"));
676        assert!(dbg.contains("test_col"));
677    }
678
679    #[test]
680    fn embeddable_content_hash_deterministic() {
681        let item = make_item("key", "some text");
682        assert_eq!(item.content_hash(), item.content_hash());
683    }
684
685    #[test]
686    fn embeddable_content_hash_changes() {
687        let a = make_item("key", "text a");
688        let b = make_item("key", "text b");
689        assert_ne!(a.content_hash(), b.content_hash());
690    }
691
692    #[test]
693    fn embeddable_payload_contains_key() {
694        let item = make_item("my-key", "desc");
695        let payload = item.to_payload();
696        assert_eq!(payload["key"], "my-key");
697    }
698
699    #[test]
700    fn sync_stats_default() {
701        let s = SyncStats::default();
702        assert_eq!(s.added, 0);
703        assert_eq!(s.updated, 0);
704        assert_eq!(s.removed, 0);
705        assert_eq!(s.unchanged, 0);
706    }
707
708    #[test]
709    fn sync_stats_debug() {
710        let s = SyncStats {
711            added: 1,
712            updated: 2,
713            removed: 3,
714            unchanged: 4,
715        };
716        let dbg = format!("{s:?}");
717        assert!(dbg.contains("added"));
718    }
719
720    #[tokio::test]
721    async fn search_raw_embed_fail_returns_error() {
722        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
723        let ns = uuid::Uuid::from_bytes([0u8; 16]);
724        let reg = EmbeddingRegistry::new(ops, "test", ns);
725        let embed_fn = |_: &str| -> EmbedFuture {
726            Box::pin(async {
727                Err(Box::new(std::io::Error::other("fail"))
728                    as Box<dyn std::error::Error + Send + Sync>)
729            })
730        };
731        let result = reg.search_raw("query", 5, embed_fn).await;
732        assert!(result.is_err());
733    }
734
735    /// Validates the dimension mismatch guard in `search_raw` (issue #3418).
736    ///
737    /// When the cached collection dimension differs from the query vector dimension,
738    /// `search_raw` must return `Err(EmbeddingRegistryError::DimensionProbe)` instead of
739    /// issuing a gRPC search that would silently return near-zero cosine scores.
740    #[tokio::test]
741    async fn search_raw_dimension_mismatch_returns_error() {
742        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
743        let ns = uuid::Uuid::from_bytes([0u8; 16]);
744        let reg = EmbeddingRegistry::new(ops, "test_dim_guard", ns);
745
746        // Simulate that the collection was created with 4-dim vectors.
747        reg.set_cached_dim(4).await;
748
749        // Query with a 2-dim vector (different model / dimension).
750        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
751        let result = reg.search_raw("query", 5, embed_fn).await;
752        assert!(
753            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
754            "expected DimensionProbe error on dimension mismatch, got: {result:?}"
755        );
756    }
757
758    /// Validates that `search_raw` does not reject a correctly-dimensioned query.
759    ///
760    /// When the cached dimension matches the query vector, the guard must pass and
761    /// the error (if any) comes from the Qdrant network call — not from the guard itself.
762    #[tokio::test]
763    async fn search_raw_matching_dimension_passes_guard() {
764        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap(); // unreachable — forces network error
765        let ns = uuid::Uuid::from_bytes([0u8; 16]);
766        let reg = EmbeddingRegistry::new(ops, "test_dim_pass", ns);
767
768        // Simulate a 2-dim collection.
769        reg.set_cached_dim(2).await;
770
771        // Query with a matching 2-dim vector.
772        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
773        let result = reg.search_raw("query", 5, embed_fn).await;
774        // The guard passes; the error is from the unreachable Qdrant instance.
775        assert!(
776            !matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
777            "guard must not fire when dimensions match"
778        );
779    }
780
781    #[tokio::test]
782    async fn sync_with_unreachable_qdrant_fails() {
783        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
784        let ns = uuid::Uuid::from_bytes([0u8; 16]);
785        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
786        let items = vec![make_item("k", "text")];
787        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
788        let result = reg.sync(&items, "model", embed_fn, None).await;
789        assert!(result.is_err());
790    }
791
792    // ── model_has_changed unit tests ──────────────────────────────────────────
793
794    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
795        let mut point = HashMap::new();
796        point.insert("embedding_model".to_owned(), model.to_owned());
797        let mut map = HashMap::new();
798        map.insert("k1".to_owned(), point);
799        map
800    }
801
802    #[test]
803    fn model_has_changed_latest_vs_bare_is_false() {
804        // Root cause of #2894: stored ":latest" suffix must not trigger recreation.
805        let existing = make_existing("nomic-embed-text-v2-moe:latest");
806        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
807    }
808
#[test]
fn model_has_changed_same_model_is_false() {
    // Stored and configured names are the very same string: nothing has changed.
    let model = "nomic-embed-text-v2-moe";
    let stored = make_existing(model);
    assert!(!model_has_changed(&stored, model));
}
814
#[test]
fn model_has_changed_different_model_is_true() {
    // A stored model name unrelated to the configured one must be reported as a change.
    let stored = make_existing("all-minilm");
    let changed = model_has_changed(&stored, "nomic-embed-text-v2-moe");
    assert!(changed);
}
820
#[test]
fn model_has_changed_empty_existing_is_false() {
    // With no stored points at all there is nothing to compare against.
    let nothing_stored = HashMap::new();
    assert!(!model_has_changed(&nothing_stored, "any-model"));
}
825
#[test]
fn model_has_changed_absent_field_with_config_model_is_true() {
    // Legacy points carry no `embedding_model` field; when a model is configured this
    // is treated as a mismatch so the collection gets recreated.
    let legacy_point = HashMap::from([("content_hash".to_owned(), "abc".to_owned())]);
    let stored = HashMap::from([("k1".to_owned(), legacy_point)]);
    assert!(model_has_changed(&stored, "nomic-embed-text-v2-moe"));
}
835
#[test]
fn model_has_changed_absent_field_with_empty_config_model_is_false() {
    // Same legacy point (no `embedding_model` field), but the configured model name is
    // empty: this combination must not be reported as a change.
    let legacy_point = HashMap::from([("content_hash".to_owned(), "abc".to_owned())]);
    let stored = HashMap::from([("k1".to_owned(), legacy_point)]);
    let changed = model_has_changed(&stored, "");
    assert!(!changed);
}
844
845    // ── concurrency guard ─────────────────────────────────────────────────────
846
#[test]
fn concurrency_zero_clamped_to_one() {
    let namespace = uuid::Uuid::from_bytes([0u8; 16]);
    let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
    let mut registry = EmbeddingRegistry::new(ops, "test", namespace);

    // The field accepts 0 without panicking.
    registry.concurrency = 0;

    // Mirror the `.max(1)` clamp and check that it lifts 0 up to 1.
    // NOTE(review): per the original comment the real clamp lives inside `sync`, which
    // this test never calls — it only exercises the clamp expression on the field value.
    let effective = registry.concurrency.max(1);
    assert_eq!(effective, 1);
}
857
858    // ── integration tests (require live Qdrant via testcontainers) ────────────
859
860    /// Test: `on_progress` fires once per successfully embedded item with correct counts.
861    #[tokio::test]
862    #[ignore = "requires Docker for Qdrant"]
863    async fn on_progress_called_once_per_successful_embed() {
864        use std::sync::{
865            Arc,
866            atomic::{AtomicUsize, Ordering},
867        };
868        use testcontainers::GenericImage;
869        use testcontainers::core::{ContainerPort, WaitFor};
870        use testcontainers::runners::AsyncRunner;
871
872        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
873            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
874            .with_wait_for(WaitFor::seconds(1))
875            .with_exposed_port(ContainerPort::Tcp(6334))
876            .start()
877            .await
878            .unwrap();
879        let port = container.get_host_port_ipv4(6334).await.unwrap();
880        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
881        let ns = uuid::Uuid::new_v4();
882        let mut reg = EmbeddingRegistry::new(ops, "test_progress", ns);
883
884        let items = [
885            make_item("a", "alpha"),
886            make_item("b", "beta"),
887            make_item("c", "gamma"),
888        ];
889        let call_count = Arc::new(AtomicUsize::new(0));
890        let last_done = Arc::new(AtomicUsize::new(0));
891        let last_total = Arc::new(AtomicUsize::new(0));
892        let cc = Arc::clone(&call_count);
893        let ld = Arc::clone(&last_done);
894        let lt = Arc::clone(&last_total);
895
896        let embed_fn =
897            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
898        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
899            Some(Box::new(move |completed, total| {
900                cc.fetch_add(1, Ordering::SeqCst);
901                ld.store(completed, Ordering::SeqCst);
902                lt.store(total, Ordering::SeqCst);
903            }));
904
905        let stats = reg
906            .sync(&items, "test-model", embed_fn, on_progress)
907            .await
908            .unwrap();
909        let n = stats.added + stats.updated;
910
911        assert_eq!(
912            call_count.load(Ordering::SeqCst),
913            n,
914            "on_progress call count"
915        );
916        assert_eq!(last_done.load(Ordering::SeqCst), n, "last completed");
917        assert_eq!(last_total.load(Ordering::SeqCst), n, "total");
918    }
919
920    /// Test: when one embed fails, the batch continues and only successful items are upserted.
921    #[tokio::test]
922    #[ignore = "requires Docker for Qdrant"]
923    async fn partial_embed_failure_skips_failed_item() {
924        use testcontainers::GenericImage;
925        use testcontainers::core::{ContainerPort, WaitFor};
926        use testcontainers::runners::AsyncRunner;
927
928        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
929            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
930            .with_wait_for(WaitFor::seconds(1))
931            .with_exposed_port(ContainerPort::Tcp(6334))
932            .start()
933            .await
934            .unwrap();
935        let port = container.get_host_port_ipv4(6334).await.unwrap();
936        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
937        let ns = uuid::Uuid::new_v4();
938        let mut reg = EmbeddingRegistry::new(ops, "test_partial", ns);
939
940        // Item whose embed_text contains "fail" will cause the embed_fn to return Err.
941        let items = [
942            make_item("ok1", "ok text"),
943            make_item("fail", "fail text"),
944            make_item("ok2", "ok text 2"),
945        ];
946
947        let embed_fn = |text: &str| -> EmbedFuture {
948            if text.contains("fail") {
949                Box::pin(async {
950                    Err(Box::new(std::io::Error::other("injected failure"))
951                        as Box<dyn std::error::Error + Send + Sync>)
952                })
953            } else {
954                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
955            }
956        };
957
958        // sync must return Ok — individual failures are warned and skipped.
959        let stats = reg
960            .sync(&items, "test-model", embed_fn, None)
961            .await
962            .unwrap();
963        assert_eq!(
964            stats.added, 2,
965            "two items should be upserted, failed one skipped"
966        );
967    }
968
969    /// Validates the full dimension-mismatch guard path against a live Qdrant instance (issue #3418).
970    ///
971    /// Creates a collection with 4-dim vectors via `sync`, then attempts a search with a 2-dim
972    /// query vector.  The guard in `search_raw` must return `Err(DimensionProbe)` before any
973    /// gRPC call reaches Qdrant, preventing the silent near-zero cosine score failure.
974    #[tokio::test]
975    #[ignore = "requires Docker for Qdrant"]
976    async fn search_raw_dimension_mismatch_returns_error_live() {
977        use testcontainers::GenericImage;
978        use testcontainers::core::{ContainerPort, WaitFor};
979        use testcontainers::runners::AsyncRunner;
980
981        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
982            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
983            .with_wait_for(WaitFor::seconds(1))
984            .with_exposed_port(ContainerPort::Tcp(6334))
985            .start()
986            .await
987            .unwrap();
988        let port = container.get_host_port_ipv4(6334).await.unwrap();
989        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
990        let ns = uuid::Uuid::new_v4();
991        let mut reg = EmbeddingRegistry::new(ops, "test_dim_guard_live", ns);
992
993        // Sync with 4-dim vectors so the collection and cache are established.
994        let items = [make_item("a", "alpha")];
995        let embed_fn_4d =
996            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0, 0.0, 0.0]) }) };
997        reg.sync(&items, "model-4d", embed_fn_4d, None)
998            .await
999            .unwrap();
1000
1001        // Search with a 2-dim query (simulates a model switch without re-sync).
1002        let embed_fn_2d = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
1003        let result = reg.search_raw("query", 5, embed_fn_2d).await;
1004        assert!(
1005            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
1006            "dimension mismatch must return DimensionProbe error, not silent near-zero scores; got: {result:?}"
1007        );
1008    }
1009}