1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use tokio::sync::RwLock;
15
16use futures::StreamExt as _;
17use qdrant_client::qdrant::{PointStruct, value::Kind};
18
19use crate::QdrantOps;
20use crate::vector_store::VectorStoreError;
21
/// Boxed, pinned, `Send` future produced by an embedding callback.
///
/// Resolves to the embedding vector on success, or to a boxed
/// `Send + Sync` error when embedding fails.
pub type EmbedFuture = Pin<
    Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>,
>;
26
/// An item that [`EmbeddingRegistry::sync`] can embed and store.
pub trait Embeddable: Send + Sync {
    /// Identifier of the item.
    ///
    /// `sync` indexes items by this value and derives the point id from it
    /// (UUID v5 of the registry namespace and the key bytes).
    fn key(&self) -> &str;

    /// Hash of the item's content.
    ///
    /// `sync` compares it with the `content_hash` string stored with the
    /// existing point to decide whether the item must be re-embedded.
    fn content_hash(&self) -> String;

    /// Text handed to the embedding callback.
    fn embed_text(&self) -> &str;

    /// JSON payload stored alongside the vector.
    ///
    /// When the value is a JSON object, `sync` adds the `content_hash` and
    /// `embedding_model` fields to it before upserting.
    /// NOTE(review): existing points are looked up via `scroll_all(.., "key")`,
    /// so the payload presumably needs a `key` field equal to [`Self::key`] — confirm.
    fn to_payload(&self) -> serde_json::Value;
}
50
/// Counters describing the outcome of one [`EmbeddingRegistry::sync`] run.
#[derive(Debug, Default, Clone)]
pub struct SyncStats {
    /// Items embedded whose key was not among the previously stored points.
    pub added: usize,
    /// Items embedded whose key was already among the previously stored points.
    pub updated: usize,
    /// Previously stored keys that are absent from the synced items.
    pub removed: usize,
    /// Items that did not need re-embedding.
    pub unchanged: usize,
}
59
/// Errors returned by [`EmbeddingRegistry`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum EmbeddingRegistryError {
    /// A vector store operation (collection, scroll, upsert, delete, search) failed.
    #[error("vector store error: {0}")]
    VectorStore(#[from] VectorStoreError),

    /// The embedding callback failed while embedding a search query.
    #[error("embedding error: {0}")]
    Embedding(String),

    /// A JSON (de)serialization step failed.
    #[error("serialization error: {0}")]
    Serialization(String),

    /// The dimension probe failed, or a query vector's dimension does not
    /// match the collection's dimension.
    #[error("dimension probe failed: {0}")]
    DimensionProbe(String),
}
76
77impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
78 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
79 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
80 }
81}
82
83impl From<serde_json::Error> for EmbeddingRegistryError {
84 fn from(e: serde_json::Error) -> Self {
85 Self::Serialization(e.to_string())
86 }
87}
88
/// Returns `name` without a trailing `:latest` tag.
///
/// Names without that exact suffix (including other tags) are returned unchanged.
fn normalize_model_name(name: &str) -> &str {
    match name.strip_suffix(":latest") {
        Some(bare) => bare,
        None => name,
    }
}
93
94fn model_has_changed(
100 existing: &HashMap<String, HashMap<String, String>>,
101 config_model: &str,
102) -> bool {
103 if config_model.is_empty() {
104 return false;
105 }
106 existing
107 .values()
108 .any(|stored| match stored.get("embedding_model") {
109 Some(m) => normalize_model_name(m) != normalize_model_name(config_model),
110 None => true,
112 })
113}
114
/// Keeps a Qdrant collection in sync with a set of [`Embeddable`] items and
/// runs vector searches against it.
///
/// Cloning shares the cached collection dimension (it lives behind an `Arc`);
/// the other fields are cloned by value.
#[derive(Clone)]
pub struct EmbeddingRegistry {
    /// Handle used for all Qdrant operations.
    ops: QdrantOps,
    /// Name of the collection this registry manages.
    collection: String,
    /// Namespace for deriving UUID v5 point ids from item keys.
    namespace: uuid::Uuid,
    /// Key → content hash of items seen as unchanged or successfully embedded by `sync`.
    hashes: HashMap<String, String>,
    /// Maximum number of embedding futures awaited concurrently during `sync`
    /// (values below 1 are treated as 1). Defaults to 4.
    pub concurrency: usize,
    /// Vector dimension of the collection, once known.
    cached_dim: Arc<RwLock<Option<u64>>>,
}
138
139impl std::fmt::Debug for EmbeddingRegistry {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("EmbeddingRegistry")
142 .field("collection", &self.collection)
143 .finish_non_exhaustive()
144 }
145}
146
147impl EmbeddingRegistry {
148 #[must_use]
150 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
151 Self {
152 ops,
153 collection: collection.into(),
154 namespace,
155 hashes: HashMap::new(),
156 concurrency: 4,
157 cached_dim: Arc::new(RwLock::new(None)),
158 }
159 }
160
161 pub async fn sync<T: Embeddable>(
172 &mut self,
173 items: &[T],
174 embedding_model: &str,
175 embed_fn: impl Fn(&str) -> EmbedFuture,
176 on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
177 ) -> Result<SyncStats, EmbeddingRegistryError> {
178 let mut stats = SyncStats::default();
179
180 self.ensure_collection(&embed_fn).await?;
181
182 let existing = self
183 .ops
184 .scroll_all(&self.collection, "key")
185 .await
186 .map_err(|e| {
187 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
188 })?;
189
190 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
191 for item in items {
192 current.insert(item.key().to_owned(), (item.content_hash(), item));
193 }
194
195 let model_changed = model_has_changed(&existing, embedding_model);
196
197 if model_changed {
198 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
199 self.recreate_collection(&embed_fn).await?;
200 }
201
202 let work_items = build_work_set(
203 ¤t,
204 &existing,
205 model_changed,
206 &mut stats,
207 &mut self.hashes,
208 );
209
210 let work_with_futures: Vec<(String, String, EmbedFuture, String, serde_json::Value)> =
213 work_items
214 .into_iter()
215 .map(|(key, hash, item)| {
216 let text = item.embed_text().to_owned();
217 let fut = embed_fn(&text);
218 let point_id = self.point_id(&key);
219 let payload = item.to_payload();
220 (key, hash, fut, point_id, payload)
221 })
222 .collect();
223
224 let points_to_upsert = embed_and_collect_points(
225 work_with_futures,
226 on_progress,
227 &existing,
228 embedding_model,
229 self.concurrency,
230 &mut stats,
231 &mut self.hashes,
232 )
233 .await?;
234
235 if !points_to_upsert.is_empty() {
236 self.ops
237 .upsert(&self.collection, points_to_upsert)
238 .await
239 .map_err(|e| {
240 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
241 })?;
242 }
243
244 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
245 .keys()
246 .filter(|key| !current.contains_key(*key))
247 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
248 .collect();
249
250 if !orphan_ids.is_empty() {
251 stats.removed = orphan_ids.len();
252 self.ops
253 .delete_by_ids(&self.collection, orphan_ids)
254 .await
255 .map_err(|e| {
256 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
257 })?;
258 }
259
260 tracing::info!(
261 added = stats.added,
262 updated = stats.updated,
263 removed = stats.removed,
264 unchanged = stats.unchanged,
265 collection = &self.collection,
266 "embeddings synced"
267 );
268
269 Ok(stats)
270 }
271
272 pub async fn search_raw(
290 &self,
291 query: &str,
292 limit: usize,
293 embed_fn: impl Fn(&str) -> EmbedFuture,
294 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
295 let query_vec = embed_fn(query)
296 .await
297 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
298
299 let collection_dim: Option<u64> = *self.cached_dim.read().await;
303
304 let collection_dim = if collection_dim.is_some() {
305 collection_dim
306 } else {
307 let probed = self
309 .ops
310 .get_collection_vector_size(&self.collection)
311 .await
312 .map_err(|e| {
313 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
314 })?;
315 if let Some(d) = probed {
316 self.set_cached_dim(d).await;
317 }
318 probed
319 };
320
321 if let Some(stored_dim) = collection_dim {
322 let query_dim = query_vec.len() as u64;
324 if query_dim != stored_dim {
325 return Err(EmbeddingRegistryError::DimensionProbe(format!(
326 "query vector dimension {query_dim} does not match collection '{}' \
327 dimension {stored_dim}; re-run sync to rebuild the collection",
328 self.collection
329 )));
330 }
331 }
332
333 let Ok(limit_u64) = u64::try_from(limit) else {
334 return Ok(Vec::new());
335 };
336
337 let results = self
338 .ops
339 .search(&self.collection, query_vec, limit_u64, None)
340 .await
341 .map_err(|e| {
342 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
343 })?;
344
345 let scored: Vec<crate::ScoredVectorPoint> = results
346 .into_iter()
347 .map(|point| {
348 let payload: HashMap<String, serde_json::Value> = point
349 .payload
350 .into_iter()
351 .filter_map(|(k, v)| {
352 let json_val = match v.kind? {
353 Kind::StringValue(s) => serde_json::Value::String(s),
354 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
355 Kind::BoolValue(b) => serde_json::Value::Bool(b),
356 Kind::DoubleValue(d) => {
357 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
358 }
359 _ => return None,
360 };
361 Some((k, json_val))
362 })
363 .collect();
364
365 let id = match point.id.and_then(|pid| pid.point_id_options) {
366 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
367 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
368 None => String::new(),
369 };
370
371 crate::ScoredVectorPoint {
372 id,
373 score: point.score,
374 payload,
375 }
376 })
377 .collect();
378
379 Ok(scored)
380 }
381
382 fn point_id(&self, key: &str) -> String {
383 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
384 }
385
386 async fn ensure_collection(
387 &self,
388 embed_fn: &impl Fn(&str) -> EmbedFuture,
389 ) -> Result<(), EmbeddingRegistryError> {
390 if !self
391 .ops
392 .collection_exists(&self.collection)
393 .await
394 .map_err(|e| {
395 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
396 })?
397 {
398 let vector_size = self.probe_vector_size(embed_fn).await?;
400 self.ops
401 .ensure_collection(&self.collection, vector_size)
402 .await
403 .map_err(|e| {
404 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
405 })?;
406 tracing::info!(
407 collection = &self.collection,
408 dimensions = vector_size,
409 "created Qdrant collection"
410 );
411 self.set_cached_dim(vector_size).await;
412 return Ok(());
413 }
414
415 let existing_size = self
416 .ops
417 .client()
418 .collection_info(&self.collection)
419 .await
420 .map_err(|e| {
421 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
422 })?
423 .result
424 .and_then(|info| info.config)
425 .and_then(|cfg| cfg.params)
426 .and_then(|params| params.vectors_config)
427 .and_then(|vc| vc.config)
428 .and_then(|cfg| match cfg {
429 qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
430 qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
433 });
434
435 let vector_size = self.probe_vector_size(embed_fn).await?;
436
437 if existing_size == Some(vector_size) {
438 self.set_cached_dim(vector_size).await;
439 return Ok(());
440 }
441
442 tracing::warn!(
443 collection = &self.collection,
444 existing = ?existing_size,
445 required = vector_size,
446 "vector dimension mismatch, recreating collection"
447 );
448 self.ops
449 .delete_collection(&self.collection)
450 .await
451 .map_err(|e| {
452 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
453 })?;
454 self.ops
455 .ensure_collection(&self.collection, vector_size)
456 .await
457 .map_err(|e| {
458 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
459 })?;
460 tracing::info!(
461 collection = &self.collection,
462 dimensions = vector_size,
463 "created Qdrant collection"
464 );
465 self.set_cached_dim(vector_size).await;
466
467 Ok(())
468 }
469
470 async fn set_cached_dim(&self, dim: u64) {
472 *self.cached_dim.write().await = Some(dim);
473 }
474
475 async fn probe_vector_size(
476 &self,
477 embed_fn: &impl Fn(&str) -> EmbedFuture,
478 ) -> Result<u64, EmbeddingRegistryError> {
479 let probe = embed_fn("dimension probe")
480 .await
481 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
482 Ok(probe.len() as u64)
484 }
485
486 async fn recreate_collection(
487 &self,
488 embed_fn: &impl Fn(&str) -> EmbedFuture,
489 ) -> Result<(), EmbeddingRegistryError> {
490 if self
491 .ops
492 .collection_exists(&self.collection)
493 .await
494 .map_err(|e| {
495 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
496 })?
497 {
498 self.ops
499 .delete_collection(&self.collection)
500 .await
501 .map_err(|e| {
502 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
503 })?;
504 tracing::info!(
505 collection = &self.collection,
506 "deleted collection for recreation"
507 );
508 }
509 self.ensure_collection(embed_fn).await
510 }
511}
512
513fn build_work_set<'a, T: Embeddable>(
519 current: &HashMap<String, (String, &'a T)>,
520 existing: &HashMap<String, HashMap<String, String>>,
521 model_changed: bool,
522 stats: &mut SyncStats,
523 hashes: &mut HashMap<String, String>,
524) -> Vec<(String, String, &'a T)> {
525 let mut work_items: Vec<(String, String, &'a T)> = Vec::new();
526 for (key, (hash, item)) in current {
527 let needs_update = if let Some(stored) = existing.get(key) {
528 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
529 } else {
530 true
531 };
532
533 if needs_update {
534 work_items.push((key.clone(), hash.clone(), *item));
535 } else {
536 stats.unchanged += 1;
537 hashes.insert(key.clone(), hash.clone());
538 }
539 }
540 work_items
541}
542
543#[allow(clippy::too_many_arguments)]
556async fn embed_and_collect_points(
557 work_items: Vec<(String, String, EmbedFuture, String, serde_json::Value)>,
558 on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
559 existing: &HashMap<String, HashMap<String, String>>,
560 embedding_model: &str,
561 concurrency: usize,
562 stats: &mut SyncStats,
563 hashes: &mut HashMap<String, String>,
564) -> Result<Vec<PointStruct>, EmbeddingRegistryError> {
565 let total = work_items.len();
566 let concurrency = concurrency.max(1);
568
569 let mut stream =
571 futures::stream::iter(work_items.into_iter().map(
572 |(key, hash, fut, point_id, payload)| async move {
573 (key, hash, fut.await, point_id, payload)
574 },
575 ))
576 .buffer_unordered(concurrency);
577
578 let mut points_to_upsert = Vec::new();
579 let mut completed: usize = 0;
580 while let Some((key, hash, result, point_id, mut payload)) = stream.next().await {
581 let vector = match result {
582 Ok(v) => v,
583 Err(e) => {
584 tracing::warn!("failed to embed item '{key}': {e:#}");
585 continue;
586 }
587 };
588
589 if let Some(obj) = payload.as_object_mut() {
590 obj.insert(
591 "content_hash".into(),
592 serde_json::Value::String(hash.clone()),
593 );
594 obj.insert(
595 "embedding_model".into(),
596 serde_json::Value::String(embedding_model.to_owned()),
597 );
598 }
599 let payload_map = QdrantOps::json_to_payload(payload)?;
600
601 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
602
603 if existing.contains_key(&key) {
604 stats.updated += 1;
605 } else {
606 stats.added += 1;
607 }
608 hashes.insert(key, hash);
609
610 completed += 1;
611 if let Some(ref cb) = on_progress {
612 cb(completed, total);
613 }
614 }
615 Ok(points_to_upsert)
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_no_suffix() {
        assert_eq!(normalize_model_name("foo"), "foo");
    }

    #[test]
    fn normalize_strips_latest() {
        assert_eq!(normalize_model_name("foo:latest"), "foo");
    }

    #[test]
    fn normalize_other_tag_unchanged() {
        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
    }

    /// Minimal [`Embeddable`] used by the tests: a key plus the text to embed.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            &self.k
        }

        fn content_hash(&self) -> String {
            let mut hasher = blake3::Hasher::new();
            hasher.update(self.text.as_bytes());
            hasher.finalize().to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            &self.text
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.into(),
            text: text.into(),
        }
    }

    #[test]
    fn registry_new_valid_url() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_col", ns);
        let dbg = format!("{reg:?}");
        assert!(dbg.contains("EmbeddingRegistry"));
        assert!(dbg.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        assert_eq!(item.content_hash(), item.content_hash());
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let a = make_item("key", "text a");
        let b = make_item("key", "text b");
        assert_ne!(a.content_hash(), b.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let item = make_item("my-key", "desc");
        let payload = item.to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let s = SyncStats::default();
        assert_eq!(s.added, 0);
        assert_eq!(s.updated, 0);
        assert_eq!(s.removed, 0);
        assert_eq!(s.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let s = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let dbg = format!("{s:?}");
        assert!(dbg.contains("added"));
    }

    /// The embedding callback runs first, so its failure surfaces as an error.
    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test", ns);
        let embed_fn = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(result.is_err());
    }

    /// With a cached dimension of 4, a 2-dimensional query must be rejected
    /// by the dimension guard.
    #[tokio::test]
    async fn search_raw_dimension_mismatch_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_dim_guard", ns);

        reg.set_cached_dim(4).await;

        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(
            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "expected DimensionProbe error on dimension mismatch, got: {result:?}"
        );
    }

    /// With matching dimensions the guard must not fire; any other outcome
    /// of the search itself is acceptable here.
    #[tokio::test]
    async fn search_raw_matching_dimension_passes_guard() {
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_dim_pass", ns);

        reg.set_cached_dim(2).await;

        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(
            !matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "guard must not fire when dimensions match"
        );
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        let items = vec![make_item("k", "text")];
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let result = reg.sync(&items, "model", embed_fn, None).await;
        assert!(result.is_err());
    }

    /// Builds a stored-points map with one point (`k1`) embedded with `model`.
    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
        let mut point = HashMap::new();
        point.insert("embedding_model".to_owned(), model.to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        map
    }

    #[test]
    fn model_has_changed_latest_vs_bare_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe:latest");
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_same_model_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe");
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_different_model_is_true() {
        let existing = make_existing("all-minilm");
        assert!(model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_empty_existing_is_false() {
        assert!(!model_has_changed(&HashMap::new(), "any-model"));
    }

    #[test]
    fn model_has_changed_absent_field_with_config_model_is_true() {
        let mut point = HashMap::new();
        point.insert("content_hash".to_owned(), "abc".to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        assert!(model_has_changed(&map, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_absent_field_with_empty_config_model_is_false() {
        let mut point = HashMap::new();
        point.insert("content_hash".to_owned(), "abc".to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        assert!(!model_has_changed(&map, ""));
    }

    /// A registry configured with `concurrency = 0` must still embed items.
    ///
    /// This drives `embed_and_collect_points` directly (no Qdrant needed), so
    /// the clamp itself is exercised; without it the stream would never poll
    /// a future and this test would hang.
    #[tokio::test]
    async fn concurrency_zero_clamped_to_one() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        reg.concurrency = 0;

        let item = make_item("k", "text");
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let work = vec![(
            item.key().to_owned(),
            item.content_hash(),
            embed_fn(item.embed_text()),
            reg.point_id(item.key()),
            item.to_payload(),
        )];
        let existing: HashMap<String, HashMap<String, String>> = HashMap::new();
        let mut stats = SyncStats::default();
        let mut hashes: HashMap<String, String> = HashMap::new();

        let points = embed_and_collect_points(
            work,
            None,
            &existing,
            "model",
            reg.concurrency,
            &mut stats,
            &mut hashes,
        )
        .await
        .unwrap();

        assert_eq!(points.len(), 1, "the single item must be embedded");
        assert_eq!(stats.added, 1);
        assert!(hashes.contains_key("k"));
    }

    /// The progress callback fires once per successfully embedded item and
    /// ends at `(n, n)` when nothing fails.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn on_progress_called_once_per_successful_embed() {
        use std::sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        };
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_progress", ns);

        let items = [
            make_item("a", "alpha"),
            make_item("b", "beta"),
            make_item("c", "gamma"),
        ];
        let call_count = Arc::new(AtomicUsize::new(0));
        let last_done = Arc::new(AtomicUsize::new(0));
        let last_total = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let ld = Arc::clone(&last_done);
        let lt = Arc::clone(&last_total);

        let embed_fn =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
            Some(Box::new(move |completed, total| {
                cc.fetch_add(1, Ordering::SeqCst);
                ld.store(completed, Ordering::SeqCst);
                lt.store(total, Ordering::SeqCst);
            }));

        let stats = reg
            .sync(&items, "test-model", embed_fn, on_progress)
            .await
            .unwrap();
        let n = stats.added + stats.updated;

        assert_eq!(
            call_count.load(Ordering::SeqCst),
            n,
            "on_progress call count"
        );
        assert_eq!(last_done.load(Ordering::SeqCst), n, "last completed");
        assert_eq!(last_total.load(Ordering::SeqCst), n, "total");
    }

    /// An item whose embedding fails is skipped; the remaining items are still upserted.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn partial_embed_failure_skips_failed_item() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_partial", ns);

        let items = [
            make_item("ok1", "ok text"),
            make_item("fail", "fail text"),
            make_item("ok2", "ok text 2"),
        ];

        let embed_fn = |text: &str| -> EmbedFuture {
            if text.contains("fail") {
                Box::pin(async {
                    Err(Box::new(std::io::Error::other("injected failure"))
                        as Box<dyn std::error::Error + Send + Sync>)
                })
            } else {
                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
            }
        };

        let stats = reg
            .sync(&items, "test-model", embed_fn, None)
            .await
            .unwrap();
        assert_eq!(
            stats.added, 2,
            "two items should be upserted, failed one skipped"
        );
    }

    /// Against a live collection built with 4-dimensional vectors, a
    /// 2-dimensional query must be rejected by the dimension guard.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn search_raw_dimension_mismatch_returns_error_live() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_dim_guard_live", ns);

        let items = [make_item("a", "alpha")];
        let embed_fn_4d =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0, 0.0, 0.0]) }) };
        reg.sync(&items, "model-4d", embed_fn_4d, None)
            .await
            .unwrap();

        let embed_fn_2d = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn_2d).await;
        assert!(
            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "dimension mismatch must return DimensionProbe error, not silent near-zero scores; got: {result:?}"
        );
    }
}