1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use tokio::sync::RwLock;
15
16use futures::StreamExt as _;
17use qdrant_client::qdrant::{PointStruct, value::Kind};
18
19use crate::QdrantOps;
20use crate::vector_store::VectorStoreError;
21
/// Future produced by an embedding callback.
///
/// It resolves either to the embedding vector for one input text or to a boxed
/// `Send + Sync` error. The future is heap-allocated and pinned so a callback may
/// return any concrete future type. It is `Send`, so it may be moved across
/// threads.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    >,
>;
26
27pub trait Embeddable: Send + Sync {
33 fn key(&self) -> &str;
35
36 fn content_hash(&self) -> String;
40
41 fn embed_text(&self) -> &str;
43
44 fn to_payload(&self) -> serde_json::Value;
49}
50
/// Per-sync counters describing what happened to the collection.
///
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare whole stats
/// values directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items that were not in the collection and were embedded and upserted.
    pub added: usize,
    /// Items already in the collection that were re-embedded and upserted.
    pub updated: usize,
    /// Stored points whose key is no longer among the synced items.
    pub removed: usize,
    /// Items whose stored content hash matched, so they were skipped.
    pub unchanged: usize,
}
59
/// Errors returned by `EmbeddingRegistry` operations.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingRegistryError {
    /// A vector-store (Qdrant) operation failed: collection management, scroll,
    /// upsert, delete, or search.
    #[error("vector store error: {0}")]
    VectorStore(#[from] VectorStoreError),

    /// The embedding callback failed while embedding a search query.
    #[error("embedding error: {0}")]
    Embedding(String),

    /// Converting a value to or from JSON failed (created from `serde_json::Error`).
    #[error("serialization error: {0}")]
    Serialization(String),

    /// Probing the embedding dimension failed, or a query vector's length does
    /// not match the collection's vector size.
    #[error("dimension probe failed: {0}")]
    DimensionProbe(String),
}
75
76impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
77 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
78 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
79 }
80}
81
82impl From<serde_json::Error> for EmbeddingRegistryError {
83 fn from(e: serde_json::Error) -> Self {
84 Self::Serialization(e.to_string())
85 }
86}
87
/// Canonicalizes a model name for comparison by removing one trailing `:latest`
/// tag, so `foo` and `foo:latest` compare equal. Any other tag is left intact.
fn normalize_model_name(name: &str) -> &str {
    match name.strip_suffix(":latest") {
        Some(untagged) => untagged,
        None => name,
    }
}
92
93fn model_has_changed(
99 existing: &HashMap<String, HashMap<String, String>>,
100 config_model: &str,
101) -> bool {
102 if config_model.is_empty() {
103 return false;
104 }
105 existing
106 .values()
107 .any(|stored| match stored.get("embedding_model") {
108 Some(m) => normalize_model_name(m) != normalize_model_name(config_model),
109 None => true,
111 })
112}
113
114#[derive(Clone)]
126pub struct EmbeddingRegistry {
127 ops: QdrantOps,
128 collection: String,
129 namespace: uuid::Uuid,
130 hashes: HashMap<String, String>,
131 pub concurrency: usize,
133 cached_dim: Arc<RwLock<Option<u64>>>,
136}
137
138impl std::fmt::Debug for EmbeddingRegistry {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("EmbeddingRegistry")
141 .field("collection", &self.collection)
142 .finish_non_exhaustive()
143 }
144}
145
146impl EmbeddingRegistry {
147 #[must_use]
149 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
150 Self {
151 ops,
152 collection: collection.into(),
153 namespace,
154 hashes: HashMap::new(),
155 concurrency: 4,
156 cached_dim: Arc::new(RwLock::new(None)),
157 }
158 }
159
160 pub async fn sync<T: Embeddable>(
171 &mut self,
172 items: &[T],
173 embedding_model: &str,
174 embed_fn: impl Fn(&str) -> EmbedFuture,
175 on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
176 ) -> Result<SyncStats, EmbeddingRegistryError> {
177 let mut stats = SyncStats::default();
178
179 self.ensure_collection(&embed_fn).await?;
180
181 let existing = self
182 .ops
183 .scroll_all(&self.collection, "key")
184 .await
185 .map_err(|e| {
186 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
187 })?;
188
189 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
190 for item in items {
191 current.insert(item.key().to_owned(), (item.content_hash(), item));
192 }
193
194 let model_changed = model_has_changed(&existing, embedding_model);
195
196 if model_changed {
197 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
198 self.recreate_collection(&embed_fn).await?;
199 }
200
201 let work_items = build_work_set(
202 ¤t,
203 &existing,
204 model_changed,
205 &mut stats,
206 &mut self.hashes,
207 );
208
209 let work_with_futures: Vec<(String, String, EmbedFuture, String, serde_json::Value)> =
212 work_items
213 .into_iter()
214 .map(|(key, hash, item)| {
215 let text = item.embed_text().to_owned();
216 let fut = embed_fn(&text);
217 let point_id = self.point_id(&key);
218 let payload = item.to_payload();
219 (key, hash, fut, point_id, payload)
220 })
221 .collect();
222
223 let points_to_upsert = embed_and_collect_points(
224 work_with_futures,
225 on_progress,
226 &existing,
227 embedding_model,
228 self.concurrency,
229 &mut stats,
230 &mut self.hashes,
231 )
232 .await?;
233
234 if !points_to_upsert.is_empty() {
235 self.ops
236 .upsert(&self.collection, points_to_upsert)
237 .await
238 .map_err(|e| {
239 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
240 })?;
241 }
242
243 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
244 .keys()
245 .filter(|key| !current.contains_key(*key))
246 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
247 .collect();
248
249 if !orphan_ids.is_empty() {
250 stats.removed = orphan_ids.len();
251 self.ops
252 .delete_by_ids(&self.collection, orphan_ids)
253 .await
254 .map_err(|e| {
255 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
256 })?;
257 }
258
259 tracing::info!(
260 added = stats.added,
261 updated = stats.updated,
262 removed = stats.removed,
263 unchanged = stats.unchanged,
264 collection = &self.collection,
265 "embeddings synced"
266 );
267
268 Ok(stats)
269 }
270
271 pub async fn search_raw(
289 &self,
290 query: &str,
291 limit: usize,
292 embed_fn: impl Fn(&str) -> EmbedFuture,
293 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
294 let query_vec = embed_fn(query)
295 .await
296 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
297
298 let collection_dim: Option<u64> = *self.cached_dim.read().await;
302
303 let collection_dim = if collection_dim.is_some() {
304 collection_dim
305 } else {
306 let probed = self
308 .ops
309 .get_collection_vector_size(&self.collection)
310 .await
311 .map_err(|e| {
312 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
313 })?;
314 if let Some(d) = probed {
315 self.set_cached_dim(d).await;
316 }
317 probed
318 };
319
320 if let Some(stored_dim) = collection_dim {
321 let query_dim = query_vec.len() as u64;
323 if query_dim != stored_dim {
324 return Err(EmbeddingRegistryError::DimensionProbe(format!(
325 "query vector dimension {query_dim} does not match collection '{}' \
326 dimension {stored_dim}; re-run sync to rebuild the collection",
327 self.collection
328 )));
329 }
330 }
331
332 let Ok(limit_u64) = u64::try_from(limit) else {
333 return Ok(Vec::new());
334 };
335
336 let results = self
337 .ops
338 .search(&self.collection, query_vec, limit_u64, None)
339 .await
340 .map_err(|e| {
341 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
342 })?;
343
344 let scored: Vec<crate::ScoredVectorPoint> = results
345 .into_iter()
346 .map(|point| {
347 let payload: HashMap<String, serde_json::Value> = point
348 .payload
349 .into_iter()
350 .filter_map(|(k, v)| {
351 let json_val = match v.kind? {
352 Kind::StringValue(s) => serde_json::Value::String(s),
353 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
354 Kind::BoolValue(b) => serde_json::Value::Bool(b),
355 Kind::DoubleValue(d) => {
356 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
357 }
358 _ => return None,
359 };
360 Some((k, json_val))
361 })
362 .collect();
363
364 let id = match point.id.and_then(|pid| pid.point_id_options) {
365 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
366 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
367 None => String::new(),
368 };
369
370 crate::ScoredVectorPoint {
371 id,
372 score: point.score,
373 payload,
374 }
375 })
376 .collect();
377
378 Ok(scored)
379 }
380
381 fn point_id(&self, key: &str) -> String {
382 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
383 }
384
385 async fn ensure_collection(
386 &self,
387 embed_fn: &impl Fn(&str) -> EmbedFuture,
388 ) -> Result<(), EmbeddingRegistryError> {
389 if !self
390 .ops
391 .collection_exists(&self.collection)
392 .await
393 .map_err(|e| {
394 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
395 })?
396 {
397 let vector_size = self.probe_vector_size(embed_fn).await?;
399 self.ops
400 .ensure_collection(&self.collection, vector_size)
401 .await
402 .map_err(|e| {
403 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
404 })?;
405 tracing::info!(
406 collection = &self.collection,
407 dimensions = vector_size,
408 "created Qdrant collection"
409 );
410 self.set_cached_dim(vector_size).await;
411 return Ok(());
412 }
413
414 let existing_size = self
415 .ops
416 .client()
417 .collection_info(&self.collection)
418 .await
419 .map_err(|e| {
420 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
421 })?
422 .result
423 .and_then(|info| info.config)
424 .and_then(|cfg| cfg.params)
425 .and_then(|params| params.vectors_config)
426 .and_then(|vc| vc.config)
427 .and_then(|cfg| match cfg {
428 qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
429 qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
432 });
433
434 let vector_size = self.probe_vector_size(embed_fn).await?;
435
436 if existing_size == Some(vector_size) {
437 self.set_cached_dim(vector_size).await;
438 return Ok(());
439 }
440
441 tracing::warn!(
442 collection = &self.collection,
443 existing = ?existing_size,
444 required = vector_size,
445 "vector dimension mismatch, recreating collection"
446 );
447 self.ops
448 .delete_collection(&self.collection)
449 .await
450 .map_err(|e| {
451 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
452 })?;
453 self.ops
454 .ensure_collection(&self.collection, vector_size)
455 .await
456 .map_err(|e| {
457 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
458 })?;
459 tracing::info!(
460 collection = &self.collection,
461 dimensions = vector_size,
462 "created Qdrant collection"
463 );
464 self.set_cached_dim(vector_size).await;
465
466 Ok(())
467 }
468
469 async fn set_cached_dim(&self, dim: u64) {
471 *self.cached_dim.write().await = Some(dim);
472 }
473
474 async fn probe_vector_size(
475 &self,
476 embed_fn: &impl Fn(&str) -> EmbedFuture,
477 ) -> Result<u64, EmbeddingRegistryError> {
478 let probe = embed_fn("dimension probe")
479 .await
480 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
481 Ok(probe.len() as u64)
483 }
484
485 async fn recreate_collection(
486 &self,
487 embed_fn: &impl Fn(&str) -> EmbedFuture,
488 ) -> Result<(), EmbeddingRegistryError> {
489 if self
490 .ops
491 .collection_exists(&self.collection)
492 .await
493 .map_err(|e| {
494 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
495 })?
496 {
497 self.ops
498 .delete_collection(&self.collection)
499 .await
500 .map_err(|e| {
501 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
502 })?;
503 tracing::info!(
504 collection = &self.collection,
505 "deleted collection for recreation"
506 );
507 }
508 self.ensure_collection(embed_fn).await
509 }
510}
511
512fn build_work_set<'a, T: Embeddable>(
518 current: &HashMap<String, (String, &'a T)>,
519 existing: &HashMap<String, HashMap<String, String>>,
520 model_changed: bool,
521 stats: &mut SyncStats,
522 hashes: &mut HashMap<String, String>,
523) -> Vec<(String, String, &'a T)> {
524 let mut work_items: Vec<(String, String, &'a T)> = Vec::new();
525 for (key, (hash, item)) in current {
526 let needs_update = if let Some(stored) = existing.get(key) {
527 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
528 } else {
529 true
530 };
531
532 if needs_update {
533 work_items.push((key.clone(), hash.clone(), *item));
534 } else {
535 stats.unchanged += 1;
536 hashes.insert(key.clone(), hash.clone());
537 }
538 }
539 work_items
540}
541
542#[allow(clippy::too_many_arguments)]
555async fn embed_and_collect_points(
556 work_items: Vec<(String, String, EmbedFuture, String, serde_json::Value)>,
557 on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
558 existing: &HashMap<String, HashMap<String, String>>,
559 embedding_model: &str,
560 concurrency: usize,
561 stats: &mut SyncStats,
562 hashes: &mut HashMap<String, String>,
563) -> Result<Vec<PointStruct>, EmbeddingRegistryError> {
564 let total = work_items.len();
565 let concurrency = concurrency.max(1);
567
568 let mut stream =
570 futures::stream::iter(work_items.into_iter().map(
571 |(key, hash, fut, point_id, payload)| async move {
572 (key, hash, fut.await, point_id, payload)
573 },
574 ))
575 .buffer_unordered(concurrency);
576
577 let mut points_to_upsert = Vec::new();
578 let mut completed: usize = 0;
579 while let Some((key, hash, result, point_id, mut payload)) = stream.next().await {
580 let vector = match result {
581 Ok(v) => v,
582 Err(e) => {
583 tracing::warn!("failed to embed item '{key}': {e:#}");
584 continue;
585 }
586 };
587
588 if let Some(obj) = payload.as_object_mut() {
589 obj.insert(
590 "content_hash".into(),
591 serde_json::Value::String(hash.clone()),
592 );
593 obj.insert(
594 "embedding_model".into(),
595 serde_json::Value::String(embedding_model.to_owned()),
596 );
597 }
598 let payload_map = QdrantOps::json_to_payload(payload)?;
599
600 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
601
602 if existing.contains_key(&key) {
603 stats.updated += 1;
604 } else {
605 stats.added += 1;
606 }
607 hashes.insert(key, hash);
608
609 completed += 1;
610 if let Some(ref cb) = on_progress {
611 cb(completed, total);
612 }
613 }
614 Ok(points_to_upsert)
615}
616
// Unit tests for the embedding registry. Tests marked `#[ignore]` start a real
// Qdrant container through testcontainers and therefore need Docker.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_no_suffix() {
        assert_eq!(normalize_model_name("foo"), "foo");
    }

    #[test]
    fn normalize_strips_latest() {
        assert_eq!(normalize_model_name("foo:latest"), "foo");
    }

    #[test]
    fn normalize_other_tag_unchanged() {
        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
    }

    /// Minimal `Embeddable` whose content hash is a blake3 digest of its text.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            &self.k
        }

        fn content_hash(&self) -> String {
            let mut hasher = blake3::Hasher::new();
            hasher.update(self.text.as_bytes());
            hasher.finalize().to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            &self.text
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.into(),
            text: text.into(),
        }
    }

    #[test]
    fn registry_new_valid_url() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_col", ns);
        let dbg = format!("{reg:?}");
        assert!(dbg.contains("EmbeddingRegistry"));
        assert!(dbg.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        assert_eq!(item.content_hash(), item.content_hash());
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let a = make_item("key", "text a");
        let b = make_item("key", "text b");
        assert_ne!(a.content_hash(), b.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let item = make_item("my-key", "desc");
        let payload = item.to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let s = SyncStats::default();
        assert_eq!(s.added, 0);
        assert_eq!(s.updated, 0);
        assert_eq!(s.removed, 0);
        assert_eq!(s.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let s = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let dbg = format!("{s:?}");
        assert!(dbg.contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test", ns);
        let embed_fn = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(result.is_err());
    }

    /// With a cached 4-dim collection, a 2-dim query must be rejected before any
    /// network call is made.
    #[tokio::test]
    async fn search_raw_dimension_mismatch_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_dim_guard", ns);

        reg.set_cached_dim(4).await;

        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(
            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "expected DimensionProbe error on dimension mismatch, got: {result:?}"
        );
    }

    /// When dimensions match, the guard must let the request through. The
    /// unreachable port makes the search itself fail with a different error.
    #[tokio::test]
    async fn search_raw_matching_dimension_passes_guard() {
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_dim_pass", ns);

        reg.set_cached_dim(2).await;

        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(
            !matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "guard must not fire when dimensions match"
        );
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        let items = vec![make_item("k", "text")];
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let result = reg.sync(&items, "model", embed_fn, None).await;
        assert!(result.is_err());
    }

    /// Builds an `existing` map with a single point `k1` tagged with `model`.
    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
        let mut point = HashMap::new();
        point.insert("embedding_model".to_owned(), model.to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        map
    }

    #[test]
    fn model_has_changed_latest_vs_bare_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe:latest");
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_same_model_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe");
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_different_model_is_true() {
        let existing = make_existing("all-minilm");
        assert!(model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_empty_existing_is_false() {
        assert!(!model_has_changed(&HashMap::new(), "any-model"));
    }

    #[test]
    fn model_has_changed_absent_field_with_config_model_is_true() {
        let mut point = HashMap::new();
        point.insert("content_hash".to_owned(), "abc".to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        assert!(model_has_changed(&map, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_absent_field_with_empty_config_model_is_false() {
        let mut point = HashMap::new();
        point.insert("content_hash".to_owned(), "abc".to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        assert!(!model_has_changed(&map, ""));
    }

    #[test]
    fn build_work_set_classifies_items() {
        let same = make_item("same", "unchanged text");
        let changed = make_item("changed", "new text");
        let fresh = make_item("fresh", "brand new");

        let mut current = HashMap::new();
        for item in [&same, &changed, &fresh] {
            current.insert(item.key().to_owned(), (item.content_hash(), item));
        }

        let mut existing: HashMap<String, HashMap<String, String>> = HashMap::new();
        existing.insert(
            "same".to_owned(),
            HashMap::from([("content_hash".to_owned(), same.content_hash())]),
        );
        existing.insert(
            "changed".to_owned(),
            HashMap::from([("content_hash".to_owned(), "stale".to_owned())]),
        );

        let mut stats = SyncStats::default();
        let mut hashes = HashMap::new();
        let mut work: Vec<String> =
            build_work_set(&current, &existing, false, &mut stats, &mut hashes)
                .into_iter()
                .map(|(key, _, _)| key)
                .collect();
        work.sort();

        assert_eq!(work, vec!["changed".to_owned(), "fresh".to_owned()]);
        assert_eq!(stats.unchanged, 1);
        assert_eq!(hashes.get("same"), Some(&same.content_hash()));
        assert!(!hashes.contains_key("changed"));
    }

    #[test]
    fn build_work_set_model_change_forces_reembed() {
        let item = make_item("k", "text");
        let mut current = HashMap::new();
        current.insert("k".to_owned(), (item.content_hash(), &item));

        let mut existing = HashMap::new();
        existing.insert(
            "k".to_owned(),
            HashMap::from([("content_hash".to_owned(), item.content_hash())]),
        );

        let mut stats = SyncStats::default();
        let mut hashes = HashMap::new();
        let work = build_work_set(&current, &existing, true, &mut stats, &mut hashes);

        assert_eq!(work.len(), 1);
        assert_eq!(stats.unchanged, 0);
        assert!(hashes.is_empty());
    }

    /// A concurrency of 0 must still make progress: `embed_and_collect_points`
    /// clamps it to 1 and processes the work item.
    #[tokio::test]
    async fn concurrency_zero_clamped_to_one() {
        let fut: EmbedFuture = Box::pin(async { Ok(vec![0.1_f32, 0.2]) });
        let point_id = uuid::Uuid::new_v5(&uuid::Uuid::from_bytes([0u8; 16]), b"a").to_string();
        let work = vec![(
            "a".to_owned(),
            "hash-a".to_owned(),
            fut,
            point_id,
            serde_json::json!({"key": "a"}),
        )];

        let mut stats = SyncStats::default();
        let mut hashes = HashMap::new();
        let points = embed_and_collect_points(
            work,
            None,
            &HashMap::new(),
            "model",
            0,
            &mut stats,
            &mut hashes,
        )
        .await
        .unwrap();

        assert_eq!(points.len(), 1);
        assert_eq!(stats.added, 1);
        assert_eq!(stats.updated, 0);
        assert_eq!(hashes.get("a").map(String::as_str), Some("hash-a"));
    }

    /// Every successful embedding triggers exactly one progress callback, and the
    /// last callback reports `(n, n)`.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn on_progress_called_once_per_successful_embed() {
        use std::sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        };
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_progress", ns);

        let items = [
            make_item("a", "alpha"),
            make_item("b", "beta"),
            make_item("c", "gamma"),
        ];
        let call_count = Arc::new(AtomicUsize::new(0));
        let last_done = Arc::new(AtomicUsize::new(0));
        let last_total = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let ld = Arc::clone(&last_done);
        let lt = Arc::clone(&last_total);

        let embed_fn =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
            Some(Box::new(move |completed, total| {
                cc.fetch_add(1, Ordering::SeqCst);
                ld.store(completed, Ordering::SeqCst);
                lt.store(total, Ordering::SeqCst);
            }));

        let stats = reg
            .sync(&items, "test-model", embed_fn, on_progress)
            .await
            .unwrap();
        let n = stats.added + stats.updated;

        assert_eq!(
            call_count.load(Ordering::SeqCst),
            n,
            "on_progress call count"
        );
        assert_eq!(last_done.load(Ordering::SeqCst), n, "last completed");
        assert_eq!(last_total.load(Ordering::SeqCst), n, "total");
    }

    /// An item whose embedding fails is skipped while the others are upserted.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn partial_embed_failure_skips_failed_item() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_partial", ns);

        let items = [
            make_item("ok1", "ok text"),
            make_item("fail", "fail text"),
            make_item("ok2", "ok text 2"),
        ];

        let embed_fn = |text: &str| -> EmbedFuture {
            if text.contains("fail") {
                Box::pin(async {
                    Err(Box::new(std::io::Error::other("injected failure"))
                        as Box<dyn std::error::Error + Send + Sync>)
                })
            } else {
                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
            }
        };

        let stats = reg
            .sync(&items, "test-model", embed_fn, None)
            .await
            .unwrap();
        assert_eq!(
            stats.added, 2,
            "two items should be upserted, failed one skipped"
        );
    }

    /// After syncing a 4-dim collection, a 2-dim query against the live server
    /// must fail with `DimensionProbe` rather than returning near-zero scores.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn search_raw_dimension_mismatch_returns_error_live() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}"), None).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_dim_guard_live", ns);

        let items = [make_item("a", "alpha")];
        let embed_fn_4d =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0, 0.0, 0.0]) }) };
        reg.sync(&items, "model-4d", embed_fn_4d, None)
            .await
            .unwrap();

        let embed_fn_2d = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![1.0_f32, 0.0]) }) };
        let result = reg.search_raw("query", 5, embed_fn_2d).await;
        assert!(
            matches!(result, Err(EmbeddingRegistryError::DimensionProbe(_))),
            "dimension mismatch must return DimensionProbe error, not silent near-zero scores; got: {result:?}"
        );
    }
}