1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use qdrant_client::qdrant::{PointStruct, value::Kind};
14
15use crate::QdrantOps;
16use crate::vector_store::VectorStoreError;
17
/// Boxed, `Send` future returned by an embedding callback.
///
/// It resolves to the embedding vector, or to a boxed, thread-safe error when embedding fails.
pub type EmbedFuture = Pin<
    Box<
        dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    >,
>;
22
23pub trait Embeddable: Send + Sync {
25 fn key(&self) -> &str;
27
28 fn content_hash(&self) -> String;
30
31 fn embed_text(&self) -> &str;
33
34 fn to_payload(&self) -> serde_json::Value;
37}
38
/// Counters describing what a single [`EmbeddingRegistry::sync`] call did.
#[derive(Clone, Debug, Default)]
pub struct SyncStats {
    /// Items with no stored point that were embedded and inserted.
    pub added: usize,
    /// Items with a stored point that were re-embedded and overwritten.
    pub updated: usize,
    /// Stored points whose key is no longer among the synced items.
    pub removed: usize,
    /// Items left untouched because their stored point was judged up to date.
    pub unchanged: usize,
}
47
48#[derive(Debug, thiserror::Error)]
50pub enum EmbeddingRegistryError {
51 #[error("vector store error: {0}")]
52 VectorStore(#[from] VectorStoreError),
53
54 #[error("embedding error: {0}")]
55 Embedding(String),
56
57 #[error("serialization error: {0}")]
58 Serialization(String),
59
60 #[error("dimension probe failed: {0}")]
61 DimensionProbe(String),
62}
63
64impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
65 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
66 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
67 }
68}
69
70impl From<serde_json::Error> for EmbeddingRegistryError {
71 fn from(e: serde_json::Error) -> Self {
72 Self::Serialization(e.to_string())
73 }
74}
75
76impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
77 fn from(e: std::num::TryFromIntError) -> Self {
78 Self::DimensionProbe(e.to_string())
79 }
80}
81
82#[derive(Clone)]
88pub struct EmbeddingRegistry {
89 ops: QdrantOps,
90 collection: String,
91 namespace: uuid::Uuid,
92 hashes: HashMap<String, String>,
93}
94
95impl std::fmt::Debug for EmbeddingRegistry {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("EmbeddingRegistry")
98 .field("collection", &self.collection)
99 .finish_non_exhaustive()
100 }
101}
102
103impl EmbeddingRegistry {
104 #[must_use]
106 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
107 Self {
108 ops,
109 collection: collection.into(),
110 namespace,
111 hashes: HashMap::new(),
112 }
113 }
114
115 pub async fn sync<T: Embeddable>(
123 &mut self,
124 items: &[T],
125 embedding_model: &str,
126 embed_fn: impl Fn(&str) -> EmbedFuture,
127 ) -> Result<SyncStats, EmbeddingRegistryError> {
128 let mut stats = SyncStats::default();
129
130 self.ensure_collection(&embed_fn).await?;
131
132 let existing = self
133 .ops
134 .scroll_all(&self.collection, "key")
135 .await
136 .map_err(|e| {
137 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
138 })?;
139
140 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
141 for item in items {
142 current.insert(item.key().to_owned(), (item.content_hash(), item));
143 }
144
145 let model_changed = existing.values().any(|stored| {
146 stored
147 .get("embedding_model")
148 .is_some_and(|m| m != embedding_model)
149 });
150
151 if model_changed {
152 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
153 self.recreate_collection(&embed_fn).await?;
154 }
155
156 let mut points_to_upsert = Vec::new();
157 for (key, (hash, item)) in ¤t {
158 let needs_update = if let Some(stored) = existing.get(key) {
159 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
160 } else {
161 true
162 };
163
164 if !needs_update {
165 stats.unchanged += 1;
166 self.hashes.insert(key.clone(), hash.clone());
167 continue;
168 }
169
170 let vector = match embed_fn(item.embed_text()).await {
171 Ok(v) => v,
172 Err(e) => {
173 tracing::warn!("failed to embed item '{key}': {e:#}");
174 continue;
175 }
176 };
177
178 let point_id = self.point_id(key);
179 let mut payload = item.to_payload();
180 if let Some(obj) = payload.as_object_mut() {
181 obj.insert(
182 "content_hash".into(),
183 serde_json::Value::String(hash.clone()),
184 );
185 obj.insert(
186 "embedding_model".into(),
187 serde_json::Value::String(embedding_model.to_owned()),
188 );
189 }
190 let payload_map = QdrantOps::json_to_payload(payload)?;
191
192 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
193
194 if existing.contains_key(key) {
195 stats.updated += 1;
196 } else {
197 stats.added += 1;
198 }
199 self.hashes.insert(key.clone(), hash.clone());
200 }
201
202 if !points_to_upsert.is_empty() {
203 self.ops
204 .upsert(&self.collection, points_to_upsert)
205 .await
206 .map_err(|e| {
207 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
208 })?;
209 }
210
211 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
212 .keys()
213 .filter(|key| !current.contains_key(*key))
214 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
215 .collect();
216
217 if !orphan_ids.is_empty() {
218 stats.removed = orphan_ids.len();
219 self.ops
220 .delete_by_ids(&self.collection, orphan_ids)
221 .await
222 .map_err(|e| {
223 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
224 })?;
225 }
226
227 tracing::info!(
228 added = stats.added,
229 updated = stats.updated,
230 removed = stats.removed,
231 unchanged = stats.unchanged,
232 collection = &self.collection,
233 "embeddings synced"
234 );
235
236 Ok(stats)
237 }
238
239 pub async fn search_raw(
247 &self,
248 query: &str,
249 limit: usize,
250 embed_fn: impl Fn(&str) -> EmbedFuture,
251 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
252 let query_vec = embed_fn(query)
253 .await
254 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
255
256 let Ok(limit_u64) = u64::try_from(limit) else {
257 return Ok(Vec::new());
258 };
259
260 let results = self
261 .ops
262 .search(&self.collection, query_vec, limit_u64, None)
263 .await
264 .map_err(|e| {
265 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
266 })?;
267
268 let scored: Vec<crate::ScoredVectorPoint> = results
269 .into_iter()
270 .map(|point| {
271 let payload: HashMap<String, serde_json::Value> = point
272 .payload
273 .into_iter()
274 .filter_map(|(k, v)| {
275 let json_val = match v.kind? {
276 Kind::StringValue(s) => serde_json::Value::String(s),
277 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
278 Kind::BoolValue(b) => serde_json::Value::Bool(b),
279 Kind::DoubleValue(d) => {
280 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
281 }
282 _ => return None,
283 };
284 Some((k, json_val))
285 })
286 .collect();
287
288 let id = match point.id.and_then(|pid| pid.point_id_options) {
289 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
290 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
291 None => String::new(),
292 };
293
294 crate::ScoredVectorPoint {
295 id,
296 score: point.score,
297 payload,
298 }
299 })
300 .collect();
301
302 Ok(scored)
303 }
304
305 fn point_id(&self, key: &str) -> String {
306 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
307 }
308
309 async fn ensure_collection(
310 &self,
311 embed_fn: &impl Fn(&str) -> EmbedFuture,
312 ) -> Result<(), EmbeddingRegistryError> {
313 if !self
314 .ops
315 .collection_exists(&self.collection)
316 .await
317 .map_err(|e| {
318 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
319 })?
320 {
321 let vector_size = self.probe_vector_size(embed_fn).await?;
323 self.ops
324 .ensure_collection(&self.collection, vector_size)
325 .await
326 .map_err(|e| {
327 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
328 })?;
329 tracing::info!(
330 collection = &self.collection,
331 dimensions = vector_size,
332 "created Qdrant collection"
333 );
334 return Ok(());
335 }
336
337 let existing_size = self
338 .ops
339 .client()
340 .collection_info(&self.collection)
341 .await
342 .map_err(|e| {
343 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
344 })?
345 .result
346 .and_then(|info| info.config)
347 .and_then(|cfg| cfg.params)
348 .and_then(|params| params.vectors_config)
349 .and_then(|vc| vc.config)
350 .and_then(|cfg| match cfg {
351 qdrant_client::qdrant::vectors_config::Config::Params(vp) => Some(vp.size),
352 qdrant_client::qdrant::vectors_config::Config::ParamsMap(_) => None,
355 });
356
357 let vector_size = self.probe_vector_size(embed_fn).await?;
358
359 if existing_size == Some(vector_size) {
360 return Ok(());
361 }
362
363 tracing::warn!(
364 collection = &self.collection,
365 existing = ?existing_size,
366 required = vector_size,
367 "vector dimension mismatch, recreating collection"
368 );
369 self.ops
370 .delete_collection(&self.collection)
371 .await
372 .map_err(|e| {
373 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
374 })?;
375 self.ops
376 .ensure_collection(&self.collection, vector_size)
377 .await
378 .map_err(|e| {
379 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
380 })?;
381 tracing::info!(
382 collection = &self.collection,
383 dimensions = vector_size,
384 "created Qdrant collection"
385 );
386
387 Ok(())
388 }
389
390 async fn probe_vector_size(
391 &self,
392 embed_fn: &impl Fn(&str) -> EmbedFuture,
393 ) -> Result<u64, EmbeddingRegistryError> {
394 let probe = embed_fn("dimension probe")
395 .await
396 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
397 Ok(u64::try_from(probe.len())?)
398 }
399
400 async fn recreate_collection(
401 &self,
402 embed_fn: &impl Fn(&str) -> EmbedFuture,
403 ) -> Result<(), EmbeddingRegistryError> {
404 if self
405 .ops
406 .collection_exists(&self.collection)
407 .await
408 .map_err(|e| {
409 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
410 })?
411 {
412 self.ops
413 .delete_collection(&self.collection)
414 .await
415 .map_err(|e| {
416 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
417 })?;
418 tracing::info!(
419 collection = &self.collection,
420 "deleted collection for recreation"
421 );
422 }
423 self.ensure_collection(embed_fn).await
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] whose content hash is the BLAKE3 digest of its text.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            // One-shot equivalent of `Hasher::new()` + `update` + `finalize`.
            blake3::hash(self.text.as_bytes()).to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.to_owned(),
            text: text.to_owned(),
        }
    }

    /// All-zero UUID namespace shared by the registry tests.
    fn zero_namespace() -> uuid::Uuid {
        uuid::Uuid::nil()
    }

    /// Embedding callback that always fails.
    fn failing_embed(_: &str) -> EmbedFuture {
        Box::pin(async {
            Err(Box::new(std::io::Error::other("fail"))
                as Box<dyn std::error::Error + Send + Sync>)
        })
    }

    /// Embedding callback that always yields the same 2-dimensional vector.
    fn constant_embed(_: &str) -> EmbedFuture {
        Box::pin(async { Ok(vec![0.1_f32, 0.2]) })
    }

    #[test]
    fn registry_new_valid_url() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let registry = EmbeddingRegistry::new(ops, "test_col", zero_namespace());
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("EmbeddingRegistry"));
        assert!(rendered.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let (first, second) = (item.content_hash(), item.content_hash());
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let hash_a = make_item("key", "text a").content_hash();
        let hash_b = make_item("key", "text b").content_hash();
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let SyncStats {
            added,
            updated,
            removed,
            unchanged,
        } = SyncStats::default();
        assert_eq!((added, updated, removed, unchanged), (0, 0, 0, 0));
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        assert!(format!("{stats:?}").contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let ops = QdrantOps::new("http://localhost:6334").unwrap();
        let registry = EmbeddingRegistry::new(ops, "test", zero_namespace());
        let outcome = registry.search_raw("query", 5, failing_embed).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        let ops = QdrantOps::new("http://127.0.0.1:1").unwrap();
        let mut registry = EmbeddingRegistry::new(ops, "test", zero_namespace());
        let items = vec![make_item("k", "text")];
        let outcome = registry.sync(&items, "model", constant_embed).await;
        assert!(outcome.is_err());
    }
}