1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12
13use qdrant_client::qdrant::{PointStruct, value::Kind};
14
15use crate::QdrantOps;
16use crate::vector_store::VectorStoreError;
17
/// Boxed, pinned, `Send` future returned by an embedding callback.
///
/// Resolves to the embedding vector for one input text, or to any boxed
/// thread-safe error when the embedding call fails.
pub type EmbedFuture = Pin<
    Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>,
>;
22
23pub trait Embeddable: Send + Sync {
25 fn key(&self) -> &str;
27
28 fn content_hash(&self) -> String;
30
31 fn embed_text(&self) -> &str;
33
34 fn to_payload(&self) -> serde_json::Value;
37}
38
/// Per-category counters reported by one `sync` run.
#[derive(Clone, Debug, Default)]
pub struct SyncStats {
    /// Items upserted whose key had no stored point.
    pub added: usize,
    /// Items upserted whose key already had a stored point.
    pub updated: usize,
    /// Stored points deleted because their key was no longer supplied.
    pub removed: usize,
    /// Items left untouched because nothing required re-embedding.
    pub unchanged: usize,
}
47
48#[derive(Debug, thiserror::Error)]
50pub enum EmbeddingRegistryError {
51 #[error("vector store error: {0}")]
52 VectorStore(#[from] VectorStoreError),
53
54 #[error("embedding error: {0}")]
55 Embedding(String),
56
57 #[error("serialization error: {0}")]
58 Serialization(String),
59
60 #[error("dimension probe failed: {0}")]
61 DimensionProbe(String),
62}
63
64impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
65 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
66 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
67 }
68}
69
70impl From<serde_json::Error> for EmbeddingRegistryError {
71 fn from(e: serde_json::Error) -> Self {
72 Self::Serialization(e.to_string())
73 }
74}
75
76impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
77 fn from(e: std::num::TryFromIntError) -> Self {
78 Self::DimensionProbe(e.to_string())
79 }
80}
81
82pub struct EmbeddingRegistry {
88 ops: QdrantOps,
89 collection: String,
90 namespace: uuid::Uuid,
91 hashes: HashMap<String, String>,
92}
93
94impl std::fmt::Debug for EmbeddingRegistry {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("EmbeddingRegistry")
97 .field("collection", &self.collection)
98 .finish_non_exhaustive()
99 }
100}
101
102impl EmbeddingRegistry {
103 #[must_use]
105 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
106 Self {
107 ops,
108 collection: collection.into(),
109 namespace,
110 hashes: HashMap::new(),
111 }
112 }
113
114 pub async fn sync<T: Embeddable>(
122 &mut self,
123 items: &[T],
124 embedding_model: &str,
125 embed_fn: impl Fn(&str) -> EmbedFuture,
126 ) -> Result<SyncStats, EmbeddingRegistryError> {
127 let mut stats = SyncStats::default();
128
129 self.ensure_collection(&embed_fn).await?;
130
131 let existing = self
132 .ops
133 .scroll_all(&self.collection, "key")
134 .await
135 .map_err(|e| {
136 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
137 })?;
138
139 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
140 for item in items {
141 current.insert(item.key().to_owned(), (item.content_hash(), item));
142 }
143
144 let model_changed = existing.values().any(|stored| {
145 stored
146 .get("embedding_model")
147 .is_some_and(|m| m != embedding_model)
148 });
149
150 if model_changed {
151 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
152 self.recreate_collection(&embed_fn).await?;
153 }
154
155 let mut points_to_upsert = Vec::new();
156 for (key, (hash, item)) in ¤t {
157 let needs_update = if let Some(stored) = existing.get(key) {
158 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
159 } else {
160 true
161 };
162
163 if !needs_update {
164 stats.unchanged += 1;
165 self.hashes.insert(key.clone(), hash.clone());
166 continue;
167 }
168
169 let vector = match embed_fn(item.embed_text()).await {
170 Ok(v) => v,
171 Err(e) => {
172 tracing::warn!("failed to embed item '{key}': {e:#}");
173 continue;
174 }
175 };
176
177 let point_id = self.point_id(key);
178 let mut payload = item.to_payload();
179 if let Some(obj) = payload.as_object_mut() {
180 obj.insert(
181 "content_hash".into(),
182 serde_json::Value::String(hash.clone()),
183 );
184 obj.insert(
185 "embedding_model".into(),
186 serde_json::Value::String(embedding_model.to_owned()),
187 );
188 }
189 let payload_map = QdrantOps::json_to_payload(payload)?;
190
191 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
192
193 if existing.contains_key(key) {
194 stats.updated += 1;
195 } else {
196 stats.added += 1;
197 }
198 self.hashes.insert(key.clone(), hash.clone());
199 }
200
201 if !points_to_upsert.is_empty() {
202 self.ops
203 .upsert(&self.collection, points_to_upsert)
204 .await
205 .map_err(|e| {
206 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
207 })?;
208 }
209
210 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
211 .keys()
212 .filter(|key| !current.contains_key(*key))
213 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
214 .collect();
215
216 if !orphan_ids.is_empty() {
217 stats.removed = orphan_ids.len();
218 self.ops
219 .delete_by_ids(&self.collection, orphan_ids)
220 .await
221 .map_err(|e| {
222 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
223 })?;
224 }
225
226 tracing::info!(
227 added = stats.added,
228 updated = stats.updated,
229 removed = stats.removed,
230 unchanged = stats.unchanged,
231 collection = &self.collection,
232 "embeddings synced"
233 );
234
235 Ok(stats)
236 }
237
238 pub async fn search_raw(
246 &self,
247 query: &str,
248 limit: usize,
249 embed_fn: impl Fn(&str) -> EmbedFuture,
250 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
251 let query_vec = embed_fn(query)
252 .await
253 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
254
255 let Ok(limit_u64) = u64::try_from(limit) else {
256 return Ok(Vec::new());
257 };
258
259 let results = self
260 .ops
261 .search(&self.collection, query_vec, limit_u64, None)
262 .await
263 .map_err(|e| {
264 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
265 })?;
266
267 let scored: Vec<crate::ScoredVectorPoint> = results
268 .into_iter()
269 .map(|point| {
270 let payload: HashMap<String, serde_json::Value> = point
271 .payload
272 .into_iter()
273 .filter_map(|(k, v)| {
274 let json_val = match v.kind? {
275 Kind::StringValue(s) => serde_json::Value::String(s),
276 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
277 Kind::BoolValue(b) => serde_json::Value::Bool(b),
278 Kind::DoubleValue(d) => {
279 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
280 }
281 _ => return None,
282 };
283 Some((k, json_val))
284 })
285 .collect();
286
287 let id = match point.id.and_then(|pid| pid.point_id_options) {
288 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
289 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
290 None => String::new(),
291 };
292
293 crate::ScoredVectorPoint {
294 id,
295 score: point.score,
296 payload,
297 }
298 })
299 .collect();
300
301 Ok(scored)
302 }
303
304 fn point_id(&self, key: &str) -> String {
305 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
306 }
307
308 async fn ensure_collection(
309 &self,
310 embed_fn: &impl Fn(&str) -> EmbedFuture,
311 ) -> Result<(), EmbeddingRegistryError> {
312 if self
313 .ops
314 .collection_exists(&self.collection)
315 .await
316 .map_err(|e| {
317 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
318 })?
319 {
320 return Ok(());
321 }
322
323 let probe = embed_fn("dimension probe")
324 .await
325 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
326 let vector_size = u64::try_from(probe.len())?;
327
328 self.ops
329 .ensure_collection(&self.collection, vector_size)
330 .await
331 .map_err(|e| {
332 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
333 })?;
334
335 tracing::info!(
336 collection = &self.collection,
337 dimensions = vector_size,
338 "created Qdrant collection"
339 );
340
341 Ok(())
342 }
343
344 async fn recreate_collection(
345 &self,
346 embed_fn: &impl Fn(&str) -> EmbedFuture,
347 ) -> Result<(), EmbeddingRegistryError> {
348 if self
349 .ops
350 .collection_exists(&self.collection)
351 .await
352 .map_err(|e| {
353 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
354 })?
355 {
356 self.ops
357 .delete_collection(&self.collection)
358 .await
359 .map_err(|e| {
360 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
361 })?;
362 tracing::info!(
363 collection = &self.collection,
364 "deleted collection for recreation"
365 );
366 }
367 self.ensure_collection(embed_fn).await
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] fixture: a key plus the text that is hashed and
    /// embedded.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            blake3::Hasher::new()
                .update(self.text.as_bytes())
                .finalize()
                .to_hex()
                .to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    /// Builds a fixture item from borrowed strings.
    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.to_owned(),
            text: text.to_owned(),
        }
    }

    /// Registry for `collection` at `url`, using an all-zero namespace.
    fn registry_at(url: &str, collection: &str) -> EmbeddingRegistry {
        let ops = QdrantOps::new(url).unwrap();
        EmbeddingRegistry::new(ops, collection, uuid::Uuid::from_bytes([0u8; 16]))
    }

    #[test]
    fn registry_new_valid_url() {
        let rendered = format!("{:?}", registry_at("http://localhost:6334", "test_col"));
        for expected in ["EmbeddingRegistry", "test_col"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let first = item.content_hash();
        let second = item.content_hash();
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let hash_a = make_item("key", "text a").content_hash();
        let hash_b = make_item("key", "text b").content_hash();
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let SyncStats {
            added,
            updated,
            removed,
            unchanged,
        } = SyncStats::default();
        assert_eq!(added, 0);
        assert_eq!(updated, 0);
        assert_eq!(removed, 0);
        assert_eq!(unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let rendered = format!("{stats:?}");
        assert!(rendered.contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        /// Embedding callback that always fails.
        fn failing_embed(_text: &str) -> EmbedFuture {
            Box::pin(async {
                let cause: Box<dyn std::error::Error + Send + Sync> =
                    Box::new(std::io::Error::other("fail"));
                Err(cause)
            })
        }

        let reg = registry_at("http://localhost:6334", "test");
        let outcome = reg.search_raw("query", 5, failing_embed).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        /// Embedding callback that always yields the same two-element vector.
        fn constant_embed(_text: &str) -> EmbedFuture {
            Box::pin(async { Ok(vec![0.1_f32, 0.2]) })
        }

        let mut reg = registry_at("http://127.0.0.1:1", "test");
        let items = vec![make_item("k", "text")];
        let outcome = reg.sync(&items, "model", constant_embed).await;
        assert!(outcome.is_err());
    }
}