Skip to main content

zeph_memory/
embedding_registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Generic embedding registry backed by Qdrant.
5//!
6//! Provides deduplication through content-hash delta tracking and collection-level
7//! embedding-model change detection.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use futures::StreamExt as _;
14use qdrant_client::qdrant::{PointStruct, value::Kind};
15
16use crate::QdrantOps;
17use crate::vector_store::VectorStoreError;
18
/// Pinned, heap-allocated, `Send` future produced by an embedding callback.
///
/// It resolves to the embedding vector on success, or to a boxed thread-safe error
/// when the embedding backend fails.
pub type EmbedFuture = Pin<
    Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>,
>;
23
24/// Domain type that can be stored in an [`EmbeddingRegistry`].
25///
26/// Implement this trait for any struct that should be embedded and persisted in Qdrant.
27/// The registry uses [`key`](Self::key) and [`content_hash`](Self::content_hash) to
28/// detect which items need to be re-embedded on each [`EmbeddingRegistry::sync`] call.
29pub trait Embeddable: Send + Sync {
30    /// Unique string key used for point-ID generation and delta tracking.
31    fn key(&self) -> &str;
32
33    /// BLAKE3 hex hash of all semantically relevant fields.
34    ///
35    /// When this hash changes between syncs the item's embedding is recomputed.
36    fn content_hash(&self) -> String;
37
38    /// Text that will be passed to the embedding model.
39    fn embed_text(&self) -> &str;
40
41    /// Full JSON payload to store in Qdrant alongside the vector.
42    ///
43    /// **Must** include a `"key"` field equal to [`Self::key()`] so
44    /// [`EmbeddingRegistry`] can recover items on scroll.
45    fn to_payload(&self) -> serde_json::Value;
46}
47
/// Counters returned by [`EmbeddingRegistry::sync`].
///
/// Items whose embedding failed are skipped by the sync and appear in none of the
/// counters.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items embedded and upserted whose key was not stored before the sync.
    pub added: usize,
    /// Items re-embedded and upserted whose key was already stored.
    pub updated: usize,
    /// Stored points deleted because their key is no longer among the synced items.
    pub removed: usize,
    /// Items left untouched because the stored content hash still matches.
    pub unchanged: usize,
}
56
57/// Errors produced by [`EmbeddingRegistry`].
58#[derive(Debug, thiserror::Error)]
59pub enum EmbeddingRegistryError {
60    #[error("vector store error: {0}")]
61    VectorStore(#[from] VectorStoreError),
62
63    #[error("embedding error: {0}")]
64    Embedding(String),
65
66    #[error("serialization error: {0}")]
67    Serialization(String),
68
69    #[error("dimension probe failed: {0}")]
70    DimensionProbe(String),
71}
72
73impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
74    fn from(e: Box<qdrant_client::QdrantError>) -> Self {
75        Self::VectorStore(VectorStoreError::Collection(e.to_string()))
76    }
77}
78
79impl From<serde_json::Error> for EmbeddingRegistryError {
80    fn from(e: serde_json::Error) -> Self {
81        Self::Serialization(e.to_string())
82    }
83}
84
85impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
86    fn from(e: std::num::TryFromIntError) -> Self {
87        Self::DimensionProbe(e.to_string())
88    }
89}
90
/// Strip a trailing `:latest` tag from a model name.
///
/// Ollama appends `:latest` when no tag is specified, so `foo` and `foo:latest` are
/// treated as the same model. Any other tag is left untouched.
fn normalize_model_name(name: &str) -> &str {
    match name.strip_suffix(":latest") {
        Some(untagged) => untagged,
        None => name,
    }
}
95
96/// Returns `true` when any stored point uses a model name that is semantically different
97/// from `config_model` after normalizing `:latest` suffixes.
98fn model_has_changed(
99    existing: &HashMap<String, HashMap<String, String>>,
100    config_model: &str,
101) -> bool {
102    existing.values().any(|stored| {
103        stored
104            .get("embedding_model")
105            .is_some_and(|m| normalize_model_name(m) != normalize_model_name(config_model))
106    })
107}
108
109/// Generic Qdrant-backed embedding registry.
110///
111/// Owns a [`QdrantOps`] instance, a collection name and a UUID namespace for
112/// deterministic point IDs (uuid v5).  The in-memory `hashes` map enables
113/// O(1) delta detection between syncs.
114#[derive(Clone)]
115pub struct EmbeddingRegistry {
116    ops: QdrantOps,
117    collection: String,
118    namespace: uuid::Uuid,
119    hashes: HashMap<String, String>,
120    /// Maximum number of embedding requests dispatched concurrently during a sync.
121    pub concurrency: usize,
122}
123
124impl std::fmt::Debug for EmbeddingRegistry {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("EmbeddingRegistry")
127            .field("collection", &self.collection)
128            .finish_non_exhaustive()
129    }
130}
131
132impl EmbeddingRegistry {
133    /// Create a registry wrapping an existing [`QdrantOps`] connection.
134    #[must_use]
135    pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
136        Self {
137            ops,
138            collection: collection.into(),
139            namespace,
140            hashes: HashMap::new(),
141            concurrency: 4,
142        }
143    }
144
145    /// Sync `items` into Qdrant, computing a content-hash delta to avoid
146    /// unnecessary re-embedding.  Re-creates the collection when the embedding
147    /// model changes.
148    ///
149    /// `on_progress`, when provided, is called after each successful embed+upsert with
150    /// `(completed, total)` counts so callers can display progress indicators.
151    ///
152    /// # Errors
153    ///
154    /// Returns [`EmbeddingRegistryError`] on Qdrant or embedding failures.
155    #[allow(clippy::too_many_lines)]
156    pub async fn sync<T: Embeddable>(
157        &mut self,
158        items: &[T],
159        embedding_model: &str,
160        embed_fn: impl Fn(&str) -> EmbedFuture,
161        on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
162    ) -> Result<SyncStats, EmbeddingRegistryError> {
163        let mut stats = SyncStats::default();
164
165        self.ensure_collection(&embed_fn).await?;
166
167        let existing = self
168            .ops
169            .scroll_all(&self.collection, "key")
170            .await
171            .map_err(|e| {
172                EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
173            })?;
174
175        let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
176        for item in items {
177            current.insert(item.key().to_owned(), (item.content_hash(), item));
178        }
179
180        let model_changed = model_has_changed(&existing, embedding_model);
181
182        if model_changed {
183            tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
184            self.recreate_collection(&embed_fn).await?;
185        }
186
187        // Collect items that need embedding.
188        let mut work_items: Vec<(String, String, &T)> = Vec::new();
189        for (key, (hash, item)) in &current {
190            let needs_update = if let Some(stored) = existing.get(key) {
191                model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
192            } else {
193                true
194            };
195
196            if needs_update {
197                work_items.push((key.clone(), hash.clone(), *item));
198            } else {
199                stats.unchanged += 1;
200                self.hashes.insert(key.clone(), hash.clone());
201            }
202        }
203
204        let total = work_items.len();
205        // Clamp concurrency to at least 1: buffer_unordered(0) silently skips all futures.
206        let concurrency = self.concurrency.max(1);
207
208        // Stream results as they complete so on_progress fires in real time, not after collect.
209        let mut stream = futures::stream::iter(work_items.into_iter().map(|(key, hash, item)| {
210            let text = item.embed_text().to_owned();
211            let fut = embed_fn(&text);
212            async move { (key, hash, fut.await) }
213        }))
214        .buffer_unordered(concurrency);
215
216        let mut points_to_upsert = Vec::new();
217        let mut completed: usize = 0;
218        while let Some((key, hash, result)) = stream.next().await {
219            let vector = match result {
220                Ok(v) => v,
221                Err(e) => {
222                    tracing::warn!("failed to embed item '{key}': {e:#}");
223                    continue;
224                }
225            };
226
227            let point_id = self.point_id(&key);
228            let item = current[&key].1;
229            let mut payload = item.to_payload();
230            if let Some(obj) = payload.as_object_mut() {
231                obj.insert(
232                    "content_hash".into(),
233                    serde_json::Value::String(hash.clone()),
234                );
235                obj.insert(
236                    "embedding_model".into(),
237                    serde_json::Value::String(embedding_model.to_owned()),
238                );
239            }
240            let payload_map = QdrantOps::json_to_payload(payload)?;
241
242            points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
243
244            if existing.contains_key(&key) {
245                stats.updated += 1;
246            } else {
247                stats.added += 1;
248            }
249            self.hashes.insert(key, hash);
250
251            completed += 1;
252            if let Some(ref cb) = on_progress {
253                cb(completed, total);
254            }
255        }
256
257        if !points_to_upsert.is_empty() {
258            self.ops
259                .upsert(&self.collection, points_to_upsert)
260                .await
261                .map_err(|e| {
262                    EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
263                })?;
264        }
265
266        let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
267            .keys()
268            .filter(|key| !current.contains_key(*key))
269            .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
270            .collect();
271
272        if !orphan_ids.is_empty() {
273            stats.removed = orphan_ids.len();
274            self.ops
275                .delete_by_ids(&self.collection, orphan_ids)
276                .await
277                .map_err(|e| {
278                    EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
279                })?;
280        }
281
282        tracing::info!(
283            added = stats.added,
284            updated = stats.updated,
285            removed = stats.removed,
286            unchanged = stats.unchanged,
287            collection = &self.collection,
288            "embeddings synced"
289        );
290
291        Ok(stats)
292    }
293
294    /// Search the collection, returning raw scored Qdrant points.
295    ///
296    /// Consumers map the payloads to their domain types.
297    ///
298    /// # Errors
299    ///
300    /// Returns [`EmbeddingRegistryError`] if embedding or Qdrant search fails.
301    pub async fn search_raw(
302        &self,
303        query: &str,
304        limit: usize,
305        embed_fn: impl Fn(&str) -> EmbedFuture,
306    ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
307        let query_vec = embed_fn(query)
308            .await
309            .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
310
311        let Ok(limit_u64) = u64::try_from(limit) else {
312            return Ok(Vec::new());
313        };
314
315        let results = self
316            .ops
317            .search(&self.collection, query_vec, limit_u64, None)
318            .await
319            .map_err(|e| {
320                EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
321            })?;
322
323        let scored: Vec<crate::ScoredVectorPoint> = results
324            .into_iter()
325            .map(|point| {
326                let payload: HashMap<String, serde_json::Value> = point
327                    .payload
328                    .into_iter()
329                    .filter_map(|(k, v)| {
330                        let json_val = match v.kind? {
331                            Kind::StringValue(s) => serde_json::Value::String(s),
332                            Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
333                            Kind::BoolValue(b) => serde_json::Value::Bool(b),
334                            Kind::DoubleValue(d) => {
335                                serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
336                            }
337                            _ => return None,
338                        };
339                        Some((k, json_val))
340                    })
341                    .collect();
342
343                let id = match point.id.and_then(|pid| pid.point_id_options) {
344                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
345                    Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
346                    None => String::new(),
347                };
348
349                crate::ScoredVectorPoint {
350                    id,
351                    score: point.score,
352                    payload,
353                }
354            })
355            .collect();
356
357        Ok(scored)
358    }
359
360    fn point_id(&self, key: &str) -> String {
361        uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
362    }
363
364    async fn ensure_collection(
365        &self,
366        embed_fn: &impl Fn(&str) -> EmbedFuture,
367    ) -> Result<(), EmbeddingRegistryError> {
368        if !self
369            .ops
370            .collection_exists(&self.collection)
371            .await
372            .map_err(|e| {
373                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
374            })?
375        {
376            // Collection does not exist — probe once and create.
377            let vector_size = self.probe_vector_size(embed_fn).await?;
378            self.ops
379                .ensure_collection(&self.collection, vector_size)
380                .await
381                .map_err(|e| {
382                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
383                })?;
384            tracing::info!(
385                collection = &self.collection,
386                dimensions = vector_size,
387                "created Qdrant collection"
388            );
389            return Ok(());
390        }
391
392        let existing_size = self
393            .ops
394            .client()
395            .collection_info(&self.collection)
396            .await
397            .map_err(|e| {
398                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
399            })?
400            .result
401            .and_then(|info| info.config)
402            .and_then(|cfg| cfg.params)
403            .and_then(|params| params.vectors_config)
404            .and_then(|vc| vc.config)
405            .and_then(|cfg| match cfg {
406                qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
407                // Named-vector collections (ParamsMap) are not supported by this registry;
408                // treat size as unknown and recreate to ensure a compatible single-vector layout.
409                qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
410            });
411
412        let vector_size = self.probe_vector_size(embed_fn).await?;
413
414        if existing_size == Some(vector_size) {
415            return Ok(());
416        }
417
418        tracing::warn!(
419            collection = &self.collection,
420            existing = ?existing_size,
421            required = vector_size,
422            "vector dimension mismatch, recreating collection"
423        );
424        self.ops
425            .delete_collection(&self.collection)
426            .await
427            .map_err(|e| {
428                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
429            })?;
430        self.ops
431            .ensure_collection(&self.collection, vector_size)
432            .await
433            .map_err(|e| {
434                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
435            })?;
436        tracing::info!(
437            collection = &self.collection,
438            dimensions = vector_size,
439            "created Qdrant collection"
440        );
441
442        Ok(())
443    }
444
445    async fn probe_vector_size(
446        &self,
447        embed_fn: &impl Fn(&str) -> EmbedFuture,
448    ) -> Result<u64, EmbeddingRegistryError> {
449        let probe = embed_fn("dimension probe")
450            .await
451            .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
452        Ok(u64::try_from(probe.len())?)
453    }
454
455    async fn recreate_collection(
456        &self,
457        embed_fn: &impl Fn(&str) -> EmbedFuture,
458    ) -> Result<(), EmbeddingRegistryError> {
459        if self
460            .ops
461            .collection_exists(&self.collection)
462            .await
463            .map_err(|e| {
464                EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
465            })?
466        {
467            self.ops
468                .delete_collection(&self.collection)
469                .await
470                .map_err(|e| {
471                    EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
472                })?;
473            tracing::info!(
474                collection = &self.collection,
475                "deleted collection for recreation"
476            );
477        }
478        self.ensure_collection(embed_fn).await
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    // ── shared fixtures ───────────────────────────────────────────────────────

    /// Minimal [`Embeddable`] whose embedded text and content hash both come from `text`.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            let mut hasher = blake3::Hasher::new();
            hasher.update(self.text.as_bytes());
            let digest = hasher.finalize();
            digest.to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.into(),
            text: text.into(),
        }
    }

    /// Registry for `collection` at `url`, using an all-zero UUID namespace.
    fn registry_at(url: &str, collection: &str) -> EmbeddingRegistry {
        let ops = QdrantOps::new(url).unwrap();
        let namespace = uuid::Uuid::from_bytes([0u8; 16]);
        EmbeddingRegistry::new(ops, collection, namespace)
    }

    /// Stored-points map holding a single point `k1` embedded with `model`.
    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
        let fields = HashMap::from([("embedding_model".to_owned(), model.to_owned())]);
        HashMap::from([("k1".to_owned(), fields)])
    }

    // ── normalize_model_name ──────────────────────────────────────────────────

    #[test]
    fn normalize_no_suffix() {
        assert_eq!(normalize_model_name("foo"), "foo");
    }

    #[test]
    fn normalize_strips_latest() {
        assert_eq!(normalize_model_name("foo:latest"), "foo");
    }

    #[test]
    fn normalize_other_tag_unchanged() {
        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
    }

    // ── construction, trait and stats basics ──────────────────────────────────

    #[test]
    fn registry_new_valid_url() {
        let registry = registry_at("http://localhost:6334", "test_col");
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("EmbeddingRegistry"));
        assert!(rendered.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let first = item.content_hash();
        let second = item.content_hash();
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let left = make_item("key", "text a");
        let right = make_item("key", "text b");
        assert_ne!(left.content_hash(), right.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let stats = SyncStats::default();
        assert_eq!(stats.added, 0);
        assert_eq!(stats.updated, 0);
        assert_eq!(stats.removed, 0);
        assert_eq!(stats.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let rendered = format!("{stats:?}");
        assert!(rendered.contains("added"));
    }

    // ── failure paths that need no live Qdrant ────────────────────────────────

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let registry = registry_at("http://localhost:6334", "test");
        // The embedder fails before any Qdrant call is made.
        let failing_embed = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let outcome = registry.search_raw("query", 5, failing_embed).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let mut registry = registry_at("http://127.0.0.1:1", "test");
        let items = vec![make_item("k", "text")];
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let outcome = registry.sync(&items, "model", embed_fn, None).await;
        assert!(outcome.is_err());
    }

    // ── model_has_changed unit tests ──────────────────────────────────────────

    #[test]
    fn model_has_changed_latest_vs_bare_is_false() {
        // Root cause of #2894: stored ":latest" suffix must not trigger recreation.
        let stored = make_existing("nomic-embed-text-v2-moe:latest");
        assert!(!model_has_changed(&stored, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_same_model_is_false() {
        let stored = make_existing("nomic-embed-text-v2-moe");
        assert!(!model_has_changed(&stored, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_different_model_is_true() {
        let stored = make_existing("all-minilm");
        assert!(model_has_changed(&stored, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_empty_existing_is_false() {
        let nothing_stored = HashMap::new();
        assert!(!model_has_changed(&nothing_stored, "any-model"));
    }

    // ── concurrency guard ─────────────────────────────────────────────────────

    #[test]
    fn concurrency_zero_clamped_to_one() {
        let mut registry = registry_at("http://localhost:6334", "test");
        registry.concurrency = 0;
        // The clamp itself lives inside `sync`; this only checks that the field can be
        // set to 0 and that the same `max(1)` expression turns it into 1.
        assert_eq!(registry.concurrency.max(1), 1);
    }

    // ── integration tests (require live Qdrant via testcontainers) ────────────

    /// Test: `on_progress` fires once per successfully embedded item with correct counts.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn on_progress_called_once_per_successful_embed() {
        use std::sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        };
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}")).unwrap();
        let namespace = uuid::Uuid::new_v4();
        let mut registry = EmbeddingRegistry::new(ops, "test_progress", namespace);

        let items = [
            make_item("a", "alpha"),
            make_item("b", "beta"),
            make_item("c", "gamma"),
        ];

        // Shared counters recording how the progress callback was invoked.
        let calls = Arc::new(AtomicUsize::new(0));
        let last_completed = Arc::new(AtomicUsize::new(0));
        let last_total = Arc::new(AtomicUsize::new(0));
        let calls_in_cb = Arc::clone(&calls);
        let completed_in_cb = Arc::clone(&last_completed);
        let total_in_cb = Arc::clone(&last_total);

        let embed_fn =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
            Some(Box::new(move |completed, total| {
                calls_in_cb.fetch_add(1, Ordering::SeqCst);
                completed_in_cb.store(completed, Ordering::SeqCst);
                total_in_cb.store(total, Ordering::SeqCst);
            }));

        let stats = registry
            .sync(&items, "test-model", embed_fn, on_progress)
            .await
            .unwrap();
        let embedded = stats.added + stats.updated;

        assert_eq!(
            calls.load(Ordering::SeqCst),
            embedded,
            "on_progress call count"
        );
        assert_eq!(
            last_completed.load(Ordering::SeqCst),
            embedded,
            "last completed"
        );
        assert_eq!(last_total.load(Ordering::SeqCst), embedded, "total");
    }

    /// Test: when one embed fails, the batch continues and only successful items are upserted.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn partial_embed_failure_skips_failed_item() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}")).unwrap();
        let namespace = uuid::Uuid::new_v4();
        let mut registry = EmbeddingRegistry::new(ops, "test_partial", namespace);

        // Any item whose embed text contains "fail" makes the embedder return Err.
        let items = [
            make_item("ok1", "ok text"),
            make_item("fail", "fail text"),
            make_item("ok2", "ok text 2"),
        ];

        let embed_fn = |text: &str| -> EmbedFuture {
            if text.contains("fail") {
                Box::pin(async {
                    Err(Box::new(std::io::Error::other("injected failure"))
                        as Box<dyn std::error::Error + Send + Sync>)
                })
            } else {
                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
            }
        };

        // sync must return Ok — individual failures are warned about and skipped.
        let stats = registry
            .sync(&items, "test-model", embed_fn, None)
            .await
            .unwrap();
        assert_eq!(
            stats.added, 2,
            "two items should be upserted, failed one skipped"
        );
    }
}