Skip to main content

zeph_memory/
embedding_registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Generic embedding registry backed by Qdrant.
5//!
6//! Provides deduplication through content-hash delta tracking and collection-level
7//! embedding-model change detection.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use tokio::sync::RwLock;
15
16use futures::StreamExt as _;
17use qdrant_client::qdrant::{PointStruct, value::Kind};
18
19use crate::QdrantOps;
20use crate::vector_store::VectorStoreError;
21
/// Pinned, heap-allocated, `Send` future produced by an embedding function.
///
/// It resolves either to the embedding vector or to a boxed, thread-safe error explaining why
/// the text could not be embedded.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    >,
>;
26
27/// Domain type that can be stored in an [`EmbeddingRegistry`].
28///
29/// Implement this trait for any struct that should be embedded and persisted in Qdrant.
30/// The registry uses [`key`](Self::key) and [`content_hash`](Self::content_hash) to
31/// detect which items need to be re-embedded on each [`EmbeddingRegistry::sync`] call.
32pub trait Embeddable: Send + Sync {
33    /// Unique string key used for point-ID generation and delta tracking.
34    fn key(&self) -> &str;
35
36    /// BLAKE3 hex hash of all semantically relevant fields.
37    ///
38    /// When this hash changes between syncs the item's embedding is recomputed.
39    fn content_hash(&self) -> String;
40
41    /// Text that will be passed to the embedding model.
42    fn embed_text(&self) -> &str;
43
44    /// Full JSON payload to store in Qdrant alongside the vector.
45    ///
46    /// **Must** include a `"key"` field equal to [`Self::key()`] so
47    /// [`EmbeddingRegistry`] can recover items on scroll.
48    fn to_payload(&self) -> serde_json::Value;
49}
50
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole stat snapshots.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items successfully embedded and upserted whose key was not stored before the sync.
    pub added: usize,
    /// Items successfully re-embedded and upserted whose key was already stored.
    pub updated: usize,
    /// Stored points deleted because their key is no longer among the synced items.
    pub removed: usize,
    /// Items skipped because they did not need re-embedding.
    pub unchanged: usize,
}
59
60/// Errors produced by [`EmbeddingRegistry`].
61#[derive(Debug, thiserror::Error)]
62pub enum EmbeddingRegistryError {
63    #[error("vector store error: {0}")]
64    VectorStore(#[from] VectorStoreError),
65
66    #[error("embedding error: {0}")]
67    Embedding(String),
68
69    #[error("serialization error: {0}")]
70    Serialization(String),
71
72    #[error("dimension probe failed: {0}")]
73    DimensionProbe(String),
74}
75
76impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
77    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
78        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
79    }
80}
81
82impl From<serde_json::Error> for EmbeddingRegistryError {
83    fn from(e: serde_json::Error) -> Self {
84        Self::Serialization(e.to_string())
85    }
86}
87
/// Strip one trailing `:latest` tag from `name`, leaving any other tag untouched.
///
/// Ollama appends `:latest` when no tag is specified, so `foo` and `foo:latest` must compare
/// as the same model.
fn normalize_model_name(name: &str) -> &str {
    match name.strip_suffix(":latest") {
        Some(untagged) => untagged,
        None => name,
    }
}
92
93/// Returns `true` when any stored point uses a model name that is semantically different
94/// from `config_model` after normalizing `:latest` suffixes.
95///
96/// A missing `embedding_model` field (legacy points from pre-#3395 sessions) is treated as a
97/// mismatch: the vector was produced by an unknown model and must be regenerated.
98fn model_has_changed(
99    existing: &HashMap<String, HashMap<String, String>>,
100    config_model: &str,
101) -> bool {
102    if config_model.is_empty() {
103        return false;
104    }
105    existing
106        .values()
107        .any(|stored| match stored.get("embedding_model") {
108            Some(m) => normalize_model_name(m) != normalize_model_name(config_model),
109            // Absent field means the point was written before the model was recorded; treat as mismatch.
110            None => true,
111        })
112}
113
114/// Generic Qdrant-backed embedding registry.
115///
116/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
117/// deterministic point IDs (uuid v5).  The in-memory `hashes` map enables
118/// O(1) delta detection between syncs.
119///
120/// The `cached_dim` field caches the collection's vector dimension after the first successful
121/// [`sync`](Self::sync) so that [`search_raw`](Self::search_raw) can validate the query vector
122/// dimension without an extra Qdrant round-trip on every call.  When a mismatch is detected,
123/// `search_raw` returns [`EmbeddingRegistryError::DimensionProbe`] instead of silently issuing a
124/// gRPC search that would return near-zero cosine scores (Qdrant gRPC behaviour on dim mismatch).
125#[derive(Clone)]
126pub struct EmbeddingRegistry {
127    ops: QdrantOps,
128    collection: String,
129    namespace: uuid::Uuid,
130    hashes: HashMap<String, String>,
131    /// Maximum number of embedding requests dispatched concurrently during a sync.
132    pub concurrency: usize,
133    /// Vector dimension confirmed during the last successful `sync`.  Shared via `Arc` so
134    /// `Clone` works without invalidating the cached value across cloned instances.
135    cached_dim: Arc<RwLock<Option<u64>>>,
136}
137
138impl std::fmt::Debug for EmbeddingRegistry {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("EmbeddingRegistry")
141            .field("collection", &self.collection)
142            .finish_non_exhaustive()
143    }
144}
145
146impl EmbeddingRegistry {
147    /// Create a registry wrapping an existing [`QdrantOps`] connection.
148    #[must_use]
149    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
150        Self {
151            ops,
152            collection: collection.into(),
153            namespace,
154            hashes: HashMap::new(),
155            concurrency: 4,
156            cached_dim: Arc::new(RwLock::new(None)),
157        }
158    }
159
160    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
161    /// unnecessary re-embedding.  Re-creates the collection when the embedding
162    /// model changes.
163    ///
164    /// `on_progress`, when provided, is called after each successful embed+upsert with
165    /// `(completed, total)` counts so callers can display progress indicators.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
170    pub async fn sync<T: Embeddable>(
171        &mut self,
172        items: &[T],
173        embedding_model: &str,
174        embed_fn: impl Fn(&str) -> EmbedFuture,
175        on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
176    ) -> Result<SyncStats, EmbeddingRegistryError> {
177        let mut stats = SyncStats::default();
178
179        self.ensure_collection(&embed_fn).await?;
180
181        let existing = self
182            .ops
183            .scroll_all(&self.collection, "key")
184            .await
185            .map_err(|e| {
186                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
187            })?;
188
189        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
190        for item in items {
191            current.insert(item.key().to_owned(), (item.content_hash(), item));
192        }
193
194        let model_changed = model_has_changed(&existing, embedding_model);
195
196        if model_changed {
197            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
198            self.recreate_collection(&embed_fn).await?;
199        }
200
201        let work_items = build_work_set(
202            &current,
203            &existing,
204            model_changed,
205            &mut stats,
206            &mut self.hashes,
207        );
208
209        // Pre-create futures, point IDs, and payloads before taking the mutable borrow on
210        // self.hashes to avoid a double-borrow on `self`.
211        let work_with_futures: Vec<(String, String, EmbedFuture, String, serde_json::Value)> =
212            work_items
213                .into_iter()
214                .map(|(key, hash, item)| {
215                    let text = item.embed_text().to_owned();
216                    let fut = embed_fn(&text);
217                    let point_id = self.point_id(&key);
218                    let payload = item.to_payload();
219                    (key, hash, fut, point_id, payload)
220                })
221                .collect();
222
223        let points_to_upsert = embed_and_collect_points(
224            work_with_futures,
225            on_progress,
226            &existing,
227            embedding_model,
228            self.concurrency,
229            &mut stats,
230            &mut self.hashes,
231        )
232        .await?;
233
234        if !points_to_upsert.is_empty() {
235            self.ops
236                .upsert(&self.collection, points_to_upsert)
237                .await
238                .map_err(|e| {
239                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
240                })?;
241        }
242
243        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
244            .keys()
245            .filter(|key| !current.contains_key(*key))
246            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
247            .collect();
248
249        if !orphan_ids.is_empty() {
250            stats.removed = orphan_ids.len();
251            self.ops
252                .delete_by_ids(&self.collection, orphan_ids)
253                .await
254                .map_err(|e| {
255                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
256                })?;
257        }
258
259        tracing::info!(
260            added = stats.added,
261            updated = stats.updated,
262            removed = stats.removed,
263            unchanged = stats.unchanged,
264            collection = &self.collection,
265            "embeddings synced"
266        );
267
268        Ok(stats)
269    }
270
271    /// Search the collection, returning raw scored Qdrant points.
272    ///
273    /// Validates that the query vector dimension matches the collection before issuing the gRPC
274    /// call.  Qdrant gRPC silently returns near-zero cosine scores (~0.022) when dimensions
275    /// mismatch instead of returning an error — this guard prevents that silent failure.
276    ///
277    /// The dimension is checked against the cache populated by the most recent [`sync`](Self::sync)
278    /// call.  If no sync has occurred (cache is `None`) the check is skipped to avoid blocking
279    /// reads before the first sync.
280    ///
281    /// Consumers map the payloads to their domain types.
282    ///
283    /// # Errors
284    ///
285    /// Returns [`EmbeddingRegistryError::DimensionProbe`] when the query vector dimension does not
286    /// match the stored collection dimension.  Returns [`EmbeddingRegistryError::Embedding`] if the
287    /// embed function fails, or [`EmbeddingRegistryError::VectorStore`] on Qdrant search failure.
288    pub async fn search_raw(
289        &self,
290        query: &str,
291        limit: usize,
292        embed_fn: impl Fn(&str) -> EmbedFuture,
293    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
294        let query_vec = embed_fn(query)
295            .await
296            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
297
298        // Guard: Qdrant gRPC returns near-zero cosine scores when the query vector dimension
299        // does not match the stored collection dimension (issue #3418).  Check the cache first
300        // (populated by sync); fall back to a live Qdrant probe only when the cache is empty.
301        let collection_dim: Option<u64> = *self.cached_dim.read().await;
302
303        let collection_dim = if collection_dim.is_some() {
304            collection_dim
305        } else {
306            // Cache miss: ask Qdrant directly (first search before any sync), then populate cache.
307            let probed = self
308                .ops
309                .get_collection_vector_size(&self.collection)
310                .await
311                .map_err(|e| {
312                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
313                })?;
314            if let Some(d) = probed {
315                self.set_cached_dim(d).await;
316            }
317            probed
318        };
319
320        if let Some(stored_dim) = collection_dim {
321            // Safe: a Vec<f32> with 4B+ elements is impossible in practice on any 64-bit platform.
322            let query_dim = query_vec.len() as u64;
323            if query_dim != stored_dim {
324                return Err(EmbeddingRegistryError::DimensionProbe(format!(
325                    "query vector dimension {query_dim} does not match collection '{}' \
326                     dimension {stored_dim}; re-run sync to rebuild the collection",
327                    self.collection
328                )));
329            }
330        }
331
332        let Ok(limit_u64) = u64::try_from(limit) else {
333            return Ok(Vec::new());
334        };
335
336        let results = self
337            .ops
338            .search(&self.collection, query_vec, limit_u64, None)
339            .await
340            .map_err(|e| {
341                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
342            })?;
343
344        let scored: Vec<crate::ScoredVectorPoint> = results
345            .into_iter()
346            .map(|point| {
347                let payload: HashMap<String, serde_json::Value> = point
348                    .payload
349                    .into_iter()
350                    .filter_map(|(k, v)| {
351                        let json_val = match v.kind? {
352                            Kind::StringValue(s) => serde_json::Value::String(s),
353                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
354                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
355                            Kind::DoubleValue(d) => {
356                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
357                            }
358                            _ => return None,
359                        };
360                        Some((k, json_val))
361                    })
362                    .collect();
363
364                let id = match point.id.and_then(|pid| pid.point_id_options) {
365                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
366                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
367                    None => String::new(),
368                };
369
370                crate::ScoredVectorPoint {
371                    id,
372                    score: point.score,
373                    payload,
374                }
375            })
376            .collect();
377
378        Ok(scored)
379    }
380
381    fn point_id(&self, key: &str) -> String {
382        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
383    }
384
385    async fn ensure_collection(
386        &self,
387        embed_fn: &impl Fn(&str) -> EmbedFuture,
388    ) -> Result<(), EmbeddingRegistryError> {
389        if !self
390            .ops
391            .collection_exists(&self.collection)
392            .await
393            .map_err(|e| {
394                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
395            })?
396        {
397            // Collection does not exist — probe once and create.
398            let vector_size = self.probe_vector_size(embed_fn).await?;
399            self.ops
400                .ensure_collection(&self.collection, vector_size)
401                .await
402                .map_err(|e| {
403                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
404                })?;
405            tracing::info!(
406                collection = &self.collection,
407                dimensions = vector_size,
408                "created Qdrant collection"
409            );
410            self.set_cached_dim(vector_size).await;
411            return Ok(());
412        }
413
414        let existing_size = self
415            .ops
416            .client()
417            .collection_info(&self.collection)
418            .await
419            .map_err(|e| {
420                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
421            })?
422            .result
423            .and_then(|info| info.config)
424            .and_then(|cfg| cfg.params)
425            .and_then(|params| params.vectors_config)
426            .and_then(|vc| vc.config)
427            .and_then(|cfg| match cfg {
428                qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
429                // Named-vector collections (ParamsMap) are not supported by this registry;
430                // treat size as unknown and recreate to ensure a compatible single-vector layout.
431                qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
432            });
433
434        let vector_size = self.probe_vector_size(embed_fn).await?;
435
436        if existing_size == Some(vector_size) {
437            self.set_cached_dim(vector_size).await;
438            return Ok(());
439        }
440
441        tracing::warn!(
442            collection = &self.collection,
443            existing = ?existing_size,
444            required = vector_size,
445            "vector dimension mismatch, recreating collection"
446        );
447        self.ops
448            .delete_collection(&self.collection)
449            .await
450            .map_err(|e| {
451                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
452            })?;
453        self.ops
454            .ensure_collection(&self.collection, vector_size)
455            .await
456            .map_err(|e| {
457                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
458            })?;
459        tracing::info!(
460            collection = &self.collection,
461            dimensions = vector_size,
462            "created Qdrant collection"
463        );
464        self.set_cached_dim(vector_size).await;
465
466        Ok(())
467    }
468
469    /// Store `dim` in the dimension cache so `search_raw` can validate without a Qdrant round-trip.
470    async fn set_cached_dim(&self, dim: u64) {
471        *self.cached_dim.write().await = Some(dim);
472    }
473
474    async fn probe_vector_size(
475        &self,
476        embed_fn: &impl Fn(&str) -> EmbedFuture,
477    ) -> Result<u64, EmbeddingRegistryError> {
478        let probe = embed_fn("dimension probe")
479            .await
480            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
481        // Safe: a Vec<f32> with 4B+ elements is impossible in practice on any 64-bit platform.
482        Ok(probe.len() as u64)
483    }
484
485    async fn recreate_collection(
486        &self,
487        embed_fn: &impl Fn(&str) -> EmbedFuture,
488    ) -> Result<(), EmbeddingRegistryError> {
489        if self
490            .ops
491            .collection_exists(&self.collection)
492            .await
493            .map_err(|e| {
494                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
495            })?
496        {
497            self.ops
498                .delete_collection(&self.collection)
499                .await
500                .map_err(|e| {
501                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
502                })?;
503            tracing::info!(
504                collection = &self.collection,
505                "deleted collection for recreation"
506            );
507        }
508        self.ensure_collection(embed_fn).await
509    }
510}
511
512/// Determine which items need embedding and update stats for unchanged ones.
513///
514/// Returns a list of `(key, hash, item)` triples that require re-embedding.  Items whose
515/// stored hash matches the current hash are counted as `unchanged` in `stats` and their
516/// hashes are pre-populated in the `hashes` map.
517fn build_work_set<'a, T: Embeddable>(
518    current: &HashMap<String, (String, &'a T)>,
519    existing: &HashMap<String, HashMap<String, String>>,
520    model_changed: bool,
521    stats: &mut SyncStats,
522    hashes: &mut HashMap<String, String>,
523) -> Vec<(String, String, &'a T)> {
524    let mut work_items: Vec<(String, String, &'a T)> = Vec::new();
525    for (key, (hash, item)) in current {
526        let needs_update = if let Some(stored) = existing.get(key) {
527            model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
528        } else {
529            true
530        };
531
532        if needs_update {
533            work_items.push((key.clone(), hash.clone(), *item));
534        } else {
535            stats.unchanged += 1;
536            hashes.insert(key.clone(), hash.clone());
537        }
538    }
539    work_items
540}
541
542/// Await each pre-created embed future and collect the resulting Qdrant points.
543///
544/// Await each pre-created embed future and collect the resulting Qdrant points.
545///
546/// `work_items` is `(key, hash, embed_future, point_id, item_payload)` — point IDs and payloads
547/// must be pre-computed to avoid a double-borrow on the `EmbeddingRegistry` when `hashes` is
548/// mutably borrowed.
549///
550/// Processes futures with bounded concurrency (`concurrency` parameter).  Calls `on_progress`
551/// after each successful embed.  Updates `stats.added`/`stats.updated` and `hashes` in place.
552///
553/// Returns a `Vec<PointStruct>` ready for upsert, or an error if payload serialization fails.
554#[allow(clippy::too_many_arguments)]
555async fn embed_and_collect_points(
556    work_items: Vec<(String, String, EmbedFuture, String, serde_json::Value)>,
557    on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
558    existing: &HashMap<String, HashMap<String, String>>,
559    embedding_model: &str,
560    concurrency: usize,
561    stats: &mut SyncStats,
562    hashes: &mut HashMap<String, String>,
563) -> Result<Vec<PointStruct>, EmbeddingRegistryError> {
564    let total = work_items.len();
565    // Clamp concurrency to at least 1: buffer_unordered(0) silently skips all futures.
566    let concurrency = concurrency.max(1);
567
568    // Stream results as they complete so on_progress fires in real time, not after collect.
569    let mut stream =
570        futures::stream::iter(work_items.into_iter().map(
571            |(key, hash, fut, point_id, payload)| async move {
572                (key, hash, fut.await, point_id, payload)
573            },
574        ))
575        .buffer_unordered(concurrency);
576
577    let mut points_to_upsert = Vec::new();
578    let mut completed: usize = 0;
579    while let Some((key, hash, result, point_id, mut payload)) = stream.next().await {
580        let vector = match result {
581            Ok(v) => v,
582            Err(e) => {
583                tracing::warn!("failed to embed item '{key}': {e:#}");
584                continue;
585            }
586        };
587
588        if let Some(obj) = payload.as_object_mut() {
589            obj.insert(
590                "content_hash".into(),
591                serde_json::Value::String(hash.clone()),
592            );
593            obj.insert(
594                "embedding_model".into(),
595                serde_json::Value::String(embedding_model.to_owned()),
596            );
597        }
598        let payload_map = QdrantOps::json_to_payload(payload)?;
599
600        points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
601
602        if existing.contains_key(&key) {
603            stats.updated += 1;
604        } else {
605            stats.added += 1;
606        }
607        hashes.insert(key, hash);
608
609        completed += 1;
610        if let Some(ref cb) = on_progress {
611            cb(completed, total);
612        }
613    }
614    Ok(points_to_upsert)
615}
616
617#[cfg(test)]
618mod tests {
619    use super::*;
620
621    #[test]
622    fn normalize_no_suffix() {
623        assert_eq!(normalize_model_name("foo"), "foo");
624    }
625
626    #[test]
627    fn normalize_strips_latest() {
628        assert_eq!(normalize_model_name("foo:latest"), "foo");
629    }
630
631    #[test]
632    fn normalize_other_tag_unchanged() {
633        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
634    }
635
636    struct TestItem {
637        k: String,
638        text: String,
639    }
640
641    impl Embeddable for TestItem {
642        fn key(&self) -> &str {
643            &self.k
644        }
645
646        fn content_hash(&self) -> String {
647            let mut hasher = blake3::Hasher::new();
648            hasher.update(self.text.as_bytes());
649            hasher.finalize().to_hex().to_string()
650        }
651
652        fn embed_text(&self) -> &str {
653            &self.text
654        }
655
656        fn to_payload(&self) -> serde_json::Value {
657            serde_json::json!({"key": self.k, "text": self.text})
658        }
659    }
660
661    fn make_item(k: &str, text: &str) -> TestItem {
662        TestItem {
663            k: k.into(),
664            text: text.into(),
665        }
666    }
667
668    #[test]
669    fn registry_new_valid_url() {
670        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
671        let ns = uuid::Uuid::from_bytes([0u8; 16]);
672        let reg = EmbeddingRegistry::new(ops, "test_col", ns);
673        let dbg = format!("{reg:?}");
674        assert!(dbg.contains("EmbeddingRegistry"));
675        assert!(dbg.contains("test_col"));
676    }
677
678    #[test]
679    fn embeddable_content_hash_deterministic() {
680        let item = make_item("key", "some text");
681        assert_eq!(item.content_hash(), item.content_hash());
682    }
683
684    #[test]
685    fn embeddable_content_hash_changes() {
686        let a = make_item("key", "text a");
687        let b = make_item("key", "text b");
688        assert_ne!(a.content_hash(), b.content_hash());
689    }
690
691    #[test]
692    fn embeddable_payload_contains_key() {
693        let item = make_item("my-key", "desc");
694        let payload = item.to_payload();
695        assert_eq!(payload["key"], "my-key");
696    }
697
698    #[test]
699    fn sync_stats_default() {
700        let s = SyncStats::default();
701        assert_eq!(s.added, 0);
702        assert_eq!(s.updated, 0);
703        assert_eq!(s.removed, 0);
704        assert_eq!(s.unchanged, 0);
705    }
706
707    #[test]
708    fn sync_stats_debug() {
709        let s = SyncStats {
710            added: 1,
711            updated: 2,
712            removed: 3,
713            unchanged: 4,
714        };
715        let dbg = format!("{s:?}");
716        assert!(dbg.contains("added"));
717    }
718
719    #[tokio::test]
720    async fn search_raw_embed_fail_returns_error() {
721        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
722        let ns = uuid::Uuid::from_bytes([0u8; 16]);
723        let reg = EmbeddingRegistry::new(ops, "test", ns);
724        let embed_fn = |_: &str| -> EmbedFuture {
725            Box::pin(async {
726                Err(Box::new(std::io::Error::other("fail"))
727                    as Box<dyn std::error::Error + Send + Sync>)
728            })
729        };
730        let result = reg.search_raw("query", 5, embed_fn).await;
731        assert!(result.is_err());
732    }
733
734    /// Validates the dimension mismatch guard in `search_raw` (issue #3418).
735    ///
736    /// When the cached collection dimension differs from the query vector dimension,
737    /// `search_raw` must return `Err(EmbeddingRegistryError::DimensionProbe)` instead of
738    /// issuing a gRPC search that would silently return near-zero cosine scores.
739    #[tokio::test]
740    async fn search_raw_dimension_mismatch_returns_error() {
741        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
742        let ns = uuid::Uuid::from_bytes([0u8; 16]);
743        let reg = EmbeddingRegistry::new(ops, "test_dim_guard", ns);
744
745        // Simulate that the collection was created with 4-dim vectors.
746        reg.set_cached_dim(4).await;
747
748        // Query with a 2-dim vector (different model / dimension).
749        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
750        let result = reg.search_raw("query", 5, embed_fn).await;
751        assert!(
752            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
753            "expected DimensionProbe error on dimension mismatch, got: {result:?}"
754        );
755    }
756
757    /// Validates that `search_raw` does not reject a correctly-dimensioned query.
758    ///
759    /// When the cached dimension matches the query vector, the guard must pass and
760    /// the error (if any) comes from the Qdrant network call — not from the guard itself.
761    #[tokio::test]
762    async fn search_raw_matching_dimension_passes_guard() {
763        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap(); // unreachable — forces network error
764        let ns = uuid::Uuid::from_bytes([0u8; 16]);
765        let reg = EmbeddingRegistry::new(ops, "test_dim_pass", ns);
766
767        // Simulate a 2-dim collection.
768        reg.set_cached_dim(2).await;
769
770        // Query with a matching 2-dim vector.
771        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
772        let result = reg.search_raw("query", 5, embed_fn).await;
773        // The guard passes; the error is from the unreachable Qdrant instance.
774        assert!(
775            !matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
776            "guard must not fire when dimensions match"
777        );
778    }
779
780    #[tokio::test]
781    async fn sync_with_unreachable_qdrant_fails() {
782        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
783        let ns = uuid::Uuid::from_bytes([0u8; 16]);
784        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
785        let items = vec![make_item("k", "text")];
786        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
787        let result = reg.sync(&items, "model", embed_fn, None).await;
788        assert!(result.is_err());
789    }
790
791    // ── model_has_changed unit tests ──────────────────────────────────────────
792
793    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
794        let mut point = HashMap::new();
795        point.insert("embedding_model".to_owned(), model.to_owned());
796        let mut map = HashMap::new();
797        map.insert("k1".to_owned(), point);
798        map
799    }
800
801    #[test]
802    fn model_has_changed_latest_vs_bare_is_false() {
803        // Root cause of #2894: stored ":latest" suffix must not trigger recreation.
804        let existing = make_existing("nomic-embed-text-v2-moe:latest");
805        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
806    }
807
808    #[test]
809    fn model_has_changed_same_model_is_false() {
810        let existing = make_existing("nomic-embed-text-v2-moe");
811        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
812    }
813
814    #[test]
815    fn model_has_changed_different_model_is_true() {
816        let existing = make_existing("all-minilm");
817        assert!(model_has_changed(&existing, "nomic-embed-text-v2-moe"));
818    }
819
820    #[test]
821    fn model_has_changed_empty_existing_is_false() {
822        assert!(!model_has_changed(&HashMap::new(), "any-model"));
823    }
824
825    #[test]
826    fn model_has_changed_absent_field_with_config_model_is_true() {
827        // Legacy points have no embedding_model field; treat as mismatch to force recreation.
828        let mut point = HashMap::new();
829        point.insert("content_hash".to_owned(), "abc".to_owned());
830        let mut map = HashMap::new();
831        map.insert("k1".to_owned(), point);
832        assert!(model_has_changed(&map, "nomic-embed-text-v2-moe"));
833    }
834
835    #[test]
836    fn model_has_changed_absent_field_with_empty_config_model_is_false() {
837        let mut point = HashMap::new();
838        point.insert("content_hash".to_owned(), "abc".to_owned());
839        let mut map = HashMap::new();
840        map.insert("k1".to_owned(), point);
841        assert!(!model_has_changed(&map, ""));
842    }
843
844    // ── concurrency guard ─────────────────────────────────────────────────────
845
    /// Intended to check that a `concurrency` of 0 is clamped to 1.
    ///
    /// NOTE(review): the assertion evaluates `0.max(1)` locally and never calls into the
    /// registry, so it only demonstrates that the `concurrency` field is writable. It does not
    /// exercise the clamp that `sync` is said to apply. Consider driving `sync` with
    /// `concurrency = 0` instead so a regression in the real guard would be caught.
    #[test]
    fn concurrency_zero_clamped_to_one() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        reg.concurrency = 0;
        // The clamp itself lives inside `sync`; this line re-computes the same `max(1)`
        // expression on the field value rather than invoking that code path.
        assert_eq!(reg.concurrency.max(1), 1);
    }
856
857    // ── integration tests (require live Qdrant via testcontainers) ────────────
858
859    /// Test: `on_progress` fires once per successfully embedded item with correct counts.
860    #[tokio::test]
861    #[ignore = "requires Docker for Qdrant"]
862    async fn on_progress_called_once_per_successful_embed() {
863        use std::sync::{
864            Arc,
865            atomic::{AtomicUsize, Ordering},
866        };
867        use testcontainers::GenericImage;
868        use testcontainers::core::{ContainerPort, WaitFor};
869        use testcontainers::runners::AsyncRunner;
870
871        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
872            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
873            .with_wait_for(WaitFor::seconds(1))
874            .with_exposed_port(ContainerPort::Tcp(6334))
875            .start()
876            .await
877            .unwrap();
878        let port = container.get_host_port_ipv4(6334).await.unwrap();
879        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
880        let ns = uuid::Uuid::new_v4();
881        let mut reg = EmbeddingRegistry::new(ops, "test_progress", ns);
882
883        let items = [
884            make_item("a", "alpha"),
885            make_item("b", "beta"),
886            make_item("c", "gamma"),
887        ];
888        let call_count = Arc::new(AtomicUsize::new(0));
889        let last_done = Arc::new(AtomicUsize::new(0));
890        let last_total = Arc::new(AtomicUsize::new(0));
891        let cc = Arc::clone(&call_count);
892        let ld = Arc::clone(&last_done);
893        let lt = Arc::clone(&last_total);
894
895        let embed_fn =
896            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
897        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
898            Some(Box::new(move |completed, total| {
899                cc.fetch_add(1, Ordering::SeqCst);
900                ld.store(completed, Ordering::SeqCst);
901                lt.store(total, Ordering::SeqCst);
902            }));
903
904        let stats = reg
905            .sync(&items, "test-model", embed_fn, on_progress)
906            .await
907            .unwrap();
908        let n = stats.added + stats.updated;
909
910        assert_eq!(
911            call_count.load(Ordering::SeqCst),
912            n,
913            "on_progress call count"
914        );
915        assert_eq!(last_done.load(Ordering::SeqCst), n, "last completed");
916        assert_eq!(last_total.load(Ordering::SeqCst), n, "total");
917    }
918
919    /// Test: when one embed fails, the batch continues and only successful items are upserted.
920    #[tokio::test]
921    #[ignore = "requires Docker for Qdrant"]
922    async fn partial_embed_failure_skips_failed_item() {
923        use testcontainers::GenericImage;
924        use testcontainers::core::{ContainerPort, WaitFor};
925        use testcontainers::runners::AsyncRunner;
926
927        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
928            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
929            .with_wait_for(WaitFor::seconds(1))
930            .with_exposed_port(ContainerPort::Tcp(6334))
931            .start()
932            .await
933            .unwrap();
934        let port = container.get_host_port_ipv4(6334).await.unwrap();
935        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
936        let ns = uuid::Uuid::new_v4();
937        let mut reg = EmbeddingRegistry::new(ops, "test_partial", ns);
938
939        // Item whose embed_text contains "fail" will cause the embed_fn to return Err.
940        let items = [
941            make_item("ok1", "ok text"),
942            make_item("fail", "fail text"),
943            make_item("ok2", "ok text 2"),
944        ];
945
946        let embed_fn = |text: &str| -> EmbedFuture {
947            if text.contains("fail") {
948                Box::pin(async {
949                    Err(Box::new(std::io::Error::other("injected failure"))
950                        as Box<dyn std::error::Error + Send + Sync>)
951                })
952            } else {
953                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
954            }
955        };
956
957        // sync must return Ok — individual failures are warned and skipped.
958        let stats = reg
959            .sync(&items, "test-model", embed_fn, None)
960            .await
961            .unwrap();
962        assert_eq!(
963            stats.added, 2,
964            "two items should be upserted, failed one skipped"
965        );
966    }
967
968    /// Validates the full dimension-mismatch guard path against a live Qdrant instance (issue #3418).
969    ///
970    /// Creates a collection with 4-dim vectors via `sync`, then attempts a search with a 2-dim
971    /// query vector.  The guard in `search_raw` must return `Err(DimensionProbe)` before any
972    /// gRPC call reaches Qdrant, preventing the silent near-zero cosine score failure.
973    #[tokio::test]
974    #[ignore = "requires Docker for Qdrant"]
975    async fn search_raw_dimension_mismatch_returns_error_live() {
976        use testcontainers::GenericImage;
977        use testcontainers::core::{ContainerPort, WaitFor};
978        use testcontainers::runners::AsyncRunner;
979
980        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
981            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
982            .with_wait_for(WaitFor::seconds(1))
983            .with_exposed_port(ContainerPort::Tcp(6334))
984            .start()
985            .await
986            .unwrap();
987        let port = container.get_host_port_ipv4(6334).await.unwrap();
988        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
989        let ns = uuid::Uuid::new_v4();
990        let mut reg = EmbeddingRegistry::new(ops, "test_dim_guard_live", ns);
991
992        // Sync with 4-dim vectors so the collection and cache are established.
993        let items = [make_item("a", "alpha")];
994        let embed_fn_4d =
995            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0, 0.0, 0.0]) }) };
996        reg.sync(&items, "model-4d", embed_fn_4d, None)
997            .await
998            .unwrap();
999
1000        // Search with a 2-dim query (simulates a model switch without re-sync).
1001        let embed_fn_2d = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
1002        let result = reg.search_raw("query", 5, embed_fn_2d).await;
1003        assert!(
1004            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
1005            "dimension mismatch must return DimensionProbe error, not silent near-zero scores; got: {result:?}"
1006        );
1007    }
1008}