1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use futures::StreamExt as _;
14use qdrant_client::qdrant::{PointStruct, value::Kind};
15
16use crate::QdrantOps;
17use crate::vector_store::VectorStoreError;
18
/// Future returned by embedding callbacks: resolves to one embedding vector or a boxed error.
///
/// The future is heap-allocated, pinned, `Send`, and owns all its data (`'static`), so it can
/// be driven concurrently with other embedding requests.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send
            + 'static,
    >,
>;
23
24pub trait Embeddable: Send + Sync {
30 fn key(&self) -> &str;
32
33 fn content_hash(&self) -> String;
37
38 fn embed_text(&self) -> &str;
40
41 fn to_payload(&self) -> serde_json::Value;
46}
47
/// Counts describing what a single [`EmbeddingRegistry::sync`] call did.
///
/// `PartialEq`/`Eq` are derived so results can be compared by value (e.g. in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items embedded and upserted whose key was not previously stored.
    pub added: usize,
    /// Items re-embedded and upserted whose key was already stored.
    pub updated: usize,
    /// Stored points deleted because their key is no longer among the synced items.
    pub removed: usize,
    /// Items skipped because their stored content hash and model were still current.
    pub unchanged: usize,
}
56
57#[derive(Debug, thiserror::Error)]
59pub enum EmbeddingRegistryError {
60 #[error("vector store error: {0}")]
61 VectorStore(#[from] VectorStoreError),
62
63 #[error("embedding error: {0}")]
64 Embedding(String),
65
66 #[error("serialization error: {0}")]
67 Serialization(String),
68
69 #[error("dimension probe failed: {0}")]
70 DimensionProbe(String),
71}
72
73impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
74 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
75 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
76 }
77}
78
79impl From<serde_json::Error> for EmbeddingRegistryError {
80 fn from(e: serde_json::Error) -> Self {
81 Self::Serialization(e.to_string())
82 }
83}
84
85impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
86 fn from(e: std::num::TryFromIntError) -> Self {
87 Self::DimensionProbe(e.to_string())
88 }
89}
90
/// Drops one trailing `:latest` tag so that `model` and `model:latest` compare equal.
///
/// Any other tag (e.g. `:v2`) is kept as-is.
fn normalize_model_name(name: &str) -> &str {
    const DEFAULT_TAG: &str = ":latest";
    name.strip_suffix(DEFAULT_TAG).unwrap_or(name)
}
95
96fn model_has_changed(
99 existing: &HashMap<String, HashMap<String, String>>,
100 config_model: &str,
101) -> bool {
102 existing.values().any(|stored| {
103 stored
104 .get("embedding_model")
105 .is_some_and(|m| normalize_model_name(m) != normalize_model_name(config_model))
106 })
107}
108
109#[derive(Clone)]
115pub struct EmbeddingRegistry {
116 ops: QdrantOps,
117 collection: String,
118 namespace: uuid::Uuid,
119 hashes: HashMap<String, String>,
120 pub concurrency: usize,
122}
123
124impl std::fmt::Debug for EmbeddingRegistry {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("EmbeddingRegistry")
127 .field("collection", &self.collection)
128 .finish_non_exhaustive()
129 }
130}
131
132impl EmbeddingRegistry {
133 #[must_use]
135 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
136 Self {
137 ops,
138 collection: collection.into(),
139 namespace,
140 hashes: HashMap::new(),
141 concurrency: 4,
142 }
143 }
144
145 #[allow(clippy::too_many_lines)]
156 pub async fn sync<T: Embeddable>(
157 &mut self,
158 items: &[T],
159 embedding_model: &str,
160 embed_fn: impl Fn(&str) -> EmbedFuture,
161 on_progress: Option<Box<dyn Fn(usize, usize) + Send>>,
162 ) -> Result<SyncStats, EmbeddingRegistryError> {
163 let mut stats = SyncStats::default();
164
165 self.ensure_collection(&embed_fn).await?;
166
167 let existing = self
168 .ops
169 .scroll_all(&self.collection, "key")
170 .await
171 .map_err(|e| {
172 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
173 })?;
174
175 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
176 for item in items {
177 current.insert(item.key().to_owned(), (item.content_hash(), item));
178 }
179
180 let model_changed = model_has_changed(&existing, embedding_model);
181
182 if model_changed {
183 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
184 self.recreate_collection(&embed_fn).await?;
185 }
186
187 let mut work_items: Vec<(String, String, &T)> = Vec::new();
189 for (key, (hash, item)) in ¤t {
190 let needs_update = if let Some(stored) = existing.get(key) {
191 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
192 } else {
193 true
194 };
195
196 if needs_update {
197 work_items.push((key.clone(), hash.clone(), *item));
198 } else {
199 stats.unchanged += 1;
200 self.hashes.insert(key.clone(), hash.clone());
201 }
202 }
203
204 let total = work_items.len();
205 let concurrency = self.concurrency.max(1);
207
208 let mut stream = futures::stream::iter(work_items.into_iter().map(|(key, hash, item)| {
210 let text = item.embed_text().to_owned();
211 let fut = embed_fn(&text);
212 async move { (key, hash, fut.await) }
213 }))
214 .buffer_unordered(concurrency);
215
216 let mut points_to_upsert = Vec::new();
217 let mut completed: usize = 0;
218 while let Some((key, hash, result)) = stream.next().await {
219 let vector = match result {
220 Ok(v) => v,
221 Err(e) => {
222 tracing::warn!("failed to embed item '{key}': {e:#}");
223 continue;
224 }
225 };
226
227 let point_id = self.point_id(&key);
228 let item = current[&key].1;
229 let mut payload = item.to_payload();
230 if let Some(obj) = payload.as_object_mut() {
231 obj.insert(
232 "content_hash".into(),
233 serde_json::Value::String(hash.clone()),
234 );
235 obj.insert(
236 "embedding_model".into(),
237 serde_json::Value::String(embedding_model.to_owned()),
238 );
239 }
240 let payload_map = QdrantOps::json_to_payload(payload)?;
241
242 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
243
244 if existing.contains_key(&key) {
245 stats.updated += 1;
246 } else {
247 stats.added += 1;
248 }
249 self.hashes.insert(key, hash);
250
251 completed += 1;
252 if let Some(ref cb) = on_progress {
253 cb(completed, total);
254 }
255 }
256
257 if !points_to_upsert.is_empty() {
258 self.ops
259 .upsert(&self.collection, points_to_upsert)
260 .await
261 .map_err(|e| {
262 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
263 })?;
264 }
265
266 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
267 .keys()
268 .filter(|key| !current.contains_key(*key))
269 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
270 .collect();
271
272 if !orphan_ids.is_empty() {
273 stats.removed = orphan_ids.len();
274 self.ops
275 .delete_by_ids(&self.collection, orphan_ids)
276 .await
277 .map_err(|e| {
278 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
279 })?;
280 }
281
282 tracing::info!(
283 added = stats.added,
284 updated = stats.updated,
285 removed = stats.removed,
286 unchanged = stats.unchanged,
287 collection = &self.collection,
288 "embeddings synced"
289 );
290
291 Ok(stats)
292 }
293
294 pub async fn search_raw(
302 &self,
303 query: &str,
304 limit: usize,
305 embed_fn: impl Fn(&str) -> EmbedFuture,
306 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
307 let query_vec = embed_fn(query)
308 .await
309 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
310
311 let Ok(limit_u64) = u64::try_from(limit) else {
312 return Ok(Vec::new());
313 };
314
315 let results = self
316 .ops
317 .search(&self.collection, query_vec, limit_u64, None)
318 .await
319 .map_err(|e| {
320 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
321 })?;
322
323 let scored: Vec<crate::ScoredVectorPoint> = results
324 .into_iter()
325 .map(|point| {
326 let payload: HashMap<String, serde_json::Value> = point
327 .payload
328 .into_iter()
329 .filter_map(|(k, v)| {
330 let json_val = match v.kind? {
331 Kind::StringValue(s) => serde_json::Value::String(s),
332 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
333 Kind::BoolValue(b) => serde_json::Value::Bool(b),
334 Kind::DoubleValue(d) => {
335 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
336 }
337 _ => return None,
338 };
339 Some((k, json_val))
340 })
341 .collect();
342
343 let id = match point.id.and_then(|pid| pid.point_id_options) {
344 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
345 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
346 None => String::new(),
347 };
348
349 crate::ScoredVectorPoint {
350 id,
351 score: point.score,
352 payload,
353 }
354 })
355 .collect();
356
357 Ok(scored)
358 }
359
360 fn point_id(&self, key: &str) -> String {
361 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
362 }
363
364 async fn ensure_collection(
365 &self,
366 embed_fn: &impl Fn(&str) -> EmbedFuture,
367 ) -> Result<(), EmbeddingRegistryError> {
368 if !self
369 .ops
370 .collection_exists(&self.collection)
371 .await
372 .map_err(|e| {
373 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
374 })?
375 {
376 let vector_size = self.probe_vector_size(embed_fn).await?;
378 self.ops
379 .ensure_collection(&self.collection, vector_size)
380 .await
381 .map_err(|e| {
382 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
383 })?;
384 tracing::info!(
385 collection = &self.collection,
386 dimensions = vector_size,
387 "created Qdrant collection"
388 );
389 return Ok(());
390 }
391
392 let existing_size = self
393 .ops
394 .client()
395 .collection_info(&self.collection)
396 .await
397 .map_err(|e| {
398 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
399 })?
400 .result
401 .and_then(|info| info.config)
402 .and_then(|cfg| cfg.params)
403 .and_then(|params| params.vectors_config)
404 .and_then(|vc| vc.config)
405 .and_then(|cfg| match cfg {
406 qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
407 qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
410 });
411
412 let vector_size = self.probe_vector_size(embed_fn).await?;
413
414 if existing_size == Some(vector_size) {
415 return Ok(());
416 }
417
418 tracing::warn!(
419 collection = &self.collection,
420 existing = ?existing_size,
421 required = vector_size,
422 "vector dimension mismatch, recreating collection"
423 );
424 self.ops
425 .delete_collection(&self.collection)
426 .await
427 .map_err(|e| {
428 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
429 })?;
430 self.ops
431 .ensure_collection(&self.collection, vector_size)
432 .await
433 .map_err(|e| {
434 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
435 })?;
436 tracing::info!(
437 collection = &self.collection,
438 dimensions = vector_size,
439 "created Qdrant collection"
440 );
441
442 Ok(())
443 }
444
445 async fn probe_vector_size(
446 &self,
447 embed_fn: &impl Fn(&str) -> EmbedFuture,
448 ) -> Result<u64, EmbeddingRegistryError> {
449 let probe = embed_fn("dimension probe")
450 .await
451 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
452 Ok(u64::try_from(probe.len())?)
453 }
454
455 async fn recreate_collection(
456 &self,
457 embed_fn: &impl Fn(&str) -> EmbedFuture,
458 ) -> Result<(), EmbeddingRegistryError> {
459 if self
460 .ops
461 .collection_exists(&self.collection)
462 .await
463 .map_err(|e| {
464 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
465 })?
466 {
467 self.ops
468 .delete_collection(&self.collection)
469 .await
470 .map_err(|e| {
471 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
472 })?;
473 tracing::info!(
474 collection = &self.collection,
475 "deleted collection for recreation"
476 );
477 }
478 self.ensure_collection(embed_fn).await
479 }
480}
481
// Unit tests for the registry. Tests that need a real Qdrant server are `#[ignore]`d and
// start one in a Docker container via `testcontainers`.
#[cfg(test)]
mod tests {
    use super::*;

    // `normalize_model_name` strips only a trailing `:latest` tag.
    #[test]
    fn normalize_no_suffix() {
        assert_eq!(normalize_model_name("foo"), "foo");
    }

    #[test]
    fn normalize_strips_latest() {
        assert_eq!(normalize_model_name("foo:latest"), "foo");
    }

    #[test]
    fn normalize_other_tag_unchanged() {
        assert_eq!(normalize_model_name("foo:v2"), "foo:v2");
    }

    // Minimal `Embeddable` for tests: the key is `k`, and the content hash is the blake3 hex
    // digest of `text`.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            &self.k
        }

        fn content_hash(&self) -> String {
            let mut hasher = blake3::Hasher::new();
            hasher.update(self.text.as_bytes());
            hasher.finalize().to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            &self.text
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.into(),
            text: text.into(),
        }
    }

    // NOTE(review): assumes `QdrantOps::new` does not connect eagerly, so no server is needed.
    #[test]
    fn registry_new_valid_url() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test_col", ns);
        let dbg = format!("{reg:?}");
        assert!(dbg.contains("EmbeddingRegistry"));
        assert!(dbg.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        assert_eq!(item.content_hash(), item.content_hash());
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let a = make_item("key", "text a");
        let b = make_item("key", "text b");
        assert_ne!(a.content_hash(), b.content_hash());
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let item = make_item("my-key", "desc");
        let payload = item.to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let s = SyncStats::default();
        assert_eq!(s.added, 0);
        assert_eq!(s.updated, 0);
        assert_eq!(s.removed, 0);
        assert_eq!(s.unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let s = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let dbg = format!("{s:?}");
        assert!(dbg.contains("added"));
    }

    // `search_raw` awaits the embedder before issuing the search, so this stub's failure
    // surfaces without a live Qdrant server.
    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let reg = EmbeddingRegistry::new(ops, "test", ns);
        let embed_fn = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let result = reg.search_raw("query", 5, embed_fn).await;
        assert!(result.is_err());
    }

    // Port 1 is presumably closed, so the first Qdrant call made by `sync` fails.
    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let ops = QdrantOps::new("http://127.0.0.1:1").unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        let items = vec![make_item("k", "text")];
        let embed_fn = |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let result = reg.sync(&items, "model", embed_fn, None).await;
        assert!(result.is_err());
    }

    // Builds a scroll result containing one point, `k1`, whose payload records `model` as its
    // `embedding_model`.
    fn make_existing(model: &str) -> HashMap<String, HashMap<String, String>> {
        let mut point = HashMap::new();
        point.insert("embedding_model".to_owned(), model.to_owned());
        let mut map = HashMap::new();
        map.insert("k1".to_owned(), point);
        map
    }

    #[test]
    fn model_has_changed_latest_vs_bare_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe:latest");
        // A bare name and its `:latest` form are treated as the same model.
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_same_model_is_false() {
        let existing = make_existing("nomic-embed-text-v2-moe");
        assert!(!model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_different_model_is_true() {
        let existing = make_existing("all-minilm");
        assert!(model_has_changed(&existing, "nomic-embed-text-v2-moe"));
    }

    #[test]
    fn model_has_changed_empty_existing_is_false() {
        assert!(!model_has_changed(&HashMap::new(), "any-model"));
    }

    // NOTE(review): this only exercises `usize::max`. It never observes the clamp that `sync`
    // applies to `concurrency`, so it would still pass if that clamp were removed.
    #[test]
    fn concurrency_zero_clamped_to_one() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let ns = uuid::Uuid::from_bytes([0u8; 16]);
        let mut reg = EmbeddingRegistry::new(ops, "test", ns);
        reg.concurrency = 0;
        assert_eq!(reg.concurrency.max(1), 1);
    }

    // Checks that `on_progress` fires once per successfully embedded item. `total` counts every
    // item scheduled for embedding, so `last_total == n` holds only because this stub never
    // fails.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn on_progress_called_once_per_successful_embed() {
        use std::sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        };
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}")).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_progress", ns);

        let items = [
            make_item("a", "alpha"),
            make_item("b", "beta"),
            make_item("c", "gamma"),
        ];
        // Shared counters let the boxed callback report back to the test body.
        let call_count = Arc::new(AtomicUsize::new(0));
        let last_done = Arc::new(AtomicUsize::new(0));
        let last_total = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let ld = Arc::clone(&last_done);
        let lt = Arc::clone(&last_total);

        let embed_fn =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) }) };
        let on_progress: Option<Box<dyn Fn(usize, usize) + Send>> =
            Some(Box::new(move |completed, total| {
                cc.fetch_add(1, Ordering::SeqCst);
                ld.store(completed, Ordering::SeqCst);
                lt.store(total, Ordering::SeqCst);
            }));

        let stats = reg
            .sync(&items, "test-model", embed_fn, on_progress)
            .await
            .unwrap();
        let n = stats.added + stats.updated;

        assert_eq!(
            call_count.load(Ordering::SeqCst),
            n,
            "on_progress call count"
        );
        assert_eq!(last_done.load(Ordering::SeqCst), n, "last completed");
        assert_eq!(last_total.load(Ordering::SeqCst), n, "total");
    }

    // The embedder fails for any text containing "fail". That item is logged and skipped,
    // while the other two are upserted into the fresh collection.
    #[tokio::test]
    #[ignore = "requires Docker for Qdrant"]
    async fn partial_embed_failure_skips_failed_item() {
        use testcontainers::GenericImage;
        use testcontainers::core::{ContainerPort, WaitFor};
        use testcontainers::runners::AsyncRunner;

        let container = GenericImage::new("qdrant/qdrant", "v1.16.0")
            .with_wait_for(WaitFor::message_on_stdout("gRPC listening"))
            .with_wait_for(WaitFor::seconds(1))
            .with_exposed_port(ContainerPort::Tcp(6334))
            .start()
            .await
            .unwrap();
        let port = container.get_host_port_ipv4(6334).await.unwrap();
        let ops = QdrantOps::new(&format!("http://127.0.0.1:{port}")).unwrap();
        let ns = uuid::Uuid::new_v4();
        let mut reg = EmbeddingRegistry::new(ops, "test_partial", ns);

        let items = [
            // Only "fail text" triggers the injected error below.
            make_item("ok1", "ok text"),
            make_item("fail", "fail text"),
            make_item("ok2", "ok text 2"),
        ];

        let embed_fn = |text: &str| -> EmbedFuture {
            if text.contains("fail") {
                Box::pin(async {
                    Err(Box::new(std::io::Error::other("injected failure"))
                        as Box<dyn std::error::Error + Send + Sync>)
                })
            } else {
                Box::pin(async { Ok(vec![0.1_f32, 0.2, 0.3, 0.4]) })
            }
        };

        // The collection starts empty, so every successful item is counted as `added`.
        let stats = reg
            .sync(&items, "test-model", embed_fn, None)
            .await
            .unwrap();
        assert_eq!(
            stats.added, 2,
            "two items should be upserted, failed one skipped"
        );
    }
}