1use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
10use qdrant_client::qdrant::{PointStruct, value::Kind};
11
12use crate::QdrantOps;
13use crate::vector_store::VectorStoreError;
14
/// Boxed, sendable future produced by an embedding callback.
///
/// It resolves to the embedding vector on success, or to a boxed
/// `Send + Sync` error on failure.
pub type EmbedFuture =
    Pin<Box<dyn Future<Output = Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
19
20pub trait Embeddable: Send + Sync {
22 fn key(&self) -> &str;
24
25 fn content_hash(&self) -> String;
27
28 fn embed_text(&self) -> &str;
30
31 fn to_payload(&self) -> serde_json::Value;
34}
35
/// Counters describing the outcome of one [`EmbeddingRegistry::sync`] call.
///
/// Implements `PartialEq`/`Eq` so that results can be compared directly,
/// e.g. in assertions.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncStats {
    /// Items that had no stored point and were queued for upsert.
    pub added: usize,
    /// Items that had a stored point and were queued for re-embedding.
    pub updated: usize,
    /// Stored points whose key was absent from the synced items.
    pub removed: usize,
    /// Items whose stored point was considered up to date.
    pub unchanged: usize,
}
44
45#[derive(Debug, thiserror::Error)]
47pub enum EmbeddingRegistryError {
48 #[error("vector store error: {0}")]
49 VectorStore(#[from] VectorStoreError),
50
51 #[error("embedding error: {0}")]
52 Embedding(String),
53
54 #[error("serialization error: {0}")]
55 Serialization(String),
56
57 #[error("dimension probe failed: {0}")]
58 DimensionProbe(String),
59}
60
61impl From<Box<qdrant_client::QdrantError>> for EmbeddingRegistryError {
62 fn from(e: Box<qdrant_client::QdrantError>) -> Self {
63 Self::VectorStore(VectorStoreError::Collection(e.to_string()))
64 }
65}
66
67impl From<serde_json::Error> for EmbeddingRegistryError {
68 fn from(e: serde_json::Error) -> Self {
69 Self::Serialization(e.to_string())
70 }
71}
72
73impl From<std::num::TryFromIntError> for EmbeddingRegistryError {
74 fn from(e: std::num::TryFromIntError) -> Self {
75 Self::DimensionProbe(e.to_string())
76 }
77}
78
79pub struct EmbeddingRegistry {
85 ops: QdrantOps,
86 collection: String,
87 namespace: uuid::Uuid,
88 hashes: HashMap<String, String>,
89}
90
91impl std::fmt::Debug for EmbeddingRegistry {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("EmbeddingRegistry")
94 .field("collection", &self.collection)
95 .finish_non_exhaustive()
96 }
97}
98
99impl EmbeddingRegistry {
100 #[must_use]
102 pub fn new(ops: QdrantOps, collection: impl Into<String>, namespace: uuid::Uuid) -> Self {
103 Self {
104 ops,
105 collection: collection.into(),
106 namespace,
107 hashes: HashMap::new(),
108 }
109 }
110
111 pub async fn sync<T: Embeddable>(
119 &mut self,
120 items: &[T],
121 embedding_model: &str,
122 embed_fn: impl Fn(&str) -> EmbedFuture,
123 ) -> Result<SyncStats, EmbeddingRegistryError> {
124 let mut stats = SyncStats::default();
125
126 self.ensure_collection(&embed_fn).await?;
127
128 let existing = self
129 .ops
130 .scroll_all(&self.collection, "key")
131 .await
132 .map_err(|e| {
133 EmbeddingRegistryError::VectorStore(VectorStoreError::Scroll(e.to_string()))
134 })?;
135
136 let mut current: HashMap<String, (String, &T)> = HashMap::with_capacity(items.len());
137 for item in items {
138 current.insert(item.key().to_owned(), (item.content_hash(), item));
139 }
140
141 let model_changed = existing.values().any(|stored| {
142 stored
143 .get("embedding_model")
144 .is_some_and(|m| m != embedding_model)
145 });
146
147 if model_changed {
148 tracing::warn!("embedding model changed to '{embedding_model}', recreating collection");
149 self.recreate_collection(&embed_fn).await?;
150 }
151
152 let mut points_to_upsert = Vec::new();
153 for (key, (hash, item)) in ¤t {
154 let needs_update = if let Some(stored) = existing.get(key) {
155 model_changed || stored.get("content_hash").is_some_and(|h| h != hash)
156 } else {
157 true
158 };
159
160 if !needs_update {
161 stats.unchanged += 1;
162 self.hashes.insert(key.clone(), hash.clone());
163 continue;
164 }
165
166 let vector = match embed_fn(item.embed_text()).await {
167 Ok(v) => v,
168 Err(e) => {
169 tracing::warn!("failed to embed item '{key}': {e:#}");
170 continue;
171 }
172 };
173
174 let point_id = self.point_id(key);
175 let mut payload = item.to_payload();
176 if let Some(obj) = payload.as_object_mut() {
177 obj.insert(
178 "content_hash".into(),
179 serde_json::Value::String(hash.clone()),
180 );
181 obj.insert(
182 "embedding_model".into(),
183 serde_json::Value::String(embedding_model.to_owned()),
184 );
185 }
186 let payload_map = QdrantOps::json_to_payload(payload)?;
187
188 points_to_upsert.push(PointStruct::new(point_id, vector, payload_map));
189
190 if existing.contains_key(key) {
191 stats.updated += 1;
192 } else {
193 stats.added += 1;
194 }
195 self.hashes.insert(key.clone(), hash.clone());
196 }
197
198 if !points_to_upsert.is_empty() {
199 self.ops
200 .upsert(&self.collection, points_to_upsert)
201 .await
202 .map_err(|e| {
203 EmbeddingRegistryError::VectorStore(VectorStoreError::Upsert(e.to_string()))
204 })?;
205 }
206
207 let orphan_ids: Vec<qdrant_client::qdrant::PointId> = existing
208 .keys()
209 .filter(|key| !current.contains_key(*key))
210 .map(|key| qdrant_client::qdrant::PointId::from(self.point_id(key).as_str()))
211 .collect();
212
213 if !orphan_ids.is_empty() {
214 stats.removed = orphan_ids.len();
215 self.ops
216 .delete_by_ids(&self.collection, orphan_ids)
217 .await
218 .map_err(|e| {
219 EmbeddingRegistryError::VectorStore(VectorStoreError::Delete(e.to_string()))
220 })?;
221 }
222
223 tracing::info!(
224 added = stats.added,
225 updated = stats.updated,
226 removed = stats.removed,
227 unchanged = stats.unchanged,
228 collection = &self.collection,
229 "embeddings synced"
230 );
231
232 Ok(stats)
233 }
234
235 pub async fn search_raw(
243 &self,
244 query: &str,
245 limit: usize,
246 embed_fn: impl Fn(&str) -> EmbedFuture,
247 ) -> Result<Vec<crate::ScoredVectorPoint>, EmbeddingRegistryError> {
248 let query_vec = embed_fn(query)
249 .await
250 .map_err(|e| EmbeddingRegistryError::Embedding(e.to_string()))?;
251
252 let Ok(limit_u64) = u64::try_from(limit) else {
253 return Ok(Vec::new());
254 };
255
256 let results = self
257 .ops
258 .search(&self.collection, query_vec, limit_u64, None)
259 .await
260 .map_err(|e| {
261 EmbeddingRegistryError::VectorStore(VectorStoreError::Search(e.to_string()))
262 })?;
263
264 let scored: Vec<crate::ScoredVectorPoint> = results
265 .into_iter()
266 .map(|point| {
267 let payload: HashMap<String, serde_json::Value> = point
268 .payload
269 .into_iter()
270 .filter_map(|(k, v)| {
271 let json_val = match v.kind? {
272 Kind::StringValue(s) => serde_json::Value::String(s),
273 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
274 Kind::BoolValue(b) => serde_json::Value::Bool(b),
275 Kind::DoubleValue(d) => {
276 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
277 }
278 _ => return None,
279 };
280 Some((k, json_val))
281 })
282 .collect();
283
284 let id = match point.id.and_then(|pid| pid.point_id_options) {
285 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
286 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
287 None => String::new(),
288 };
289
290 crate::ScoredVectorPoint {
291 id,
292 score: point.score,
293 payload,
294 }
295 })
296 .collect();
297
298 Ok(scored)
299 }
300
301 fn point_id(&self, key: &str) -> String {
302 uuid::Uuid::new_v5(&self.namespace, key.as_bytes()).to_string()
303 }
304
305 async fn ensure_collection(
306 &self,
307 embed_fn: &impl Fn(&str) -> EmbedFuture,
308 ) -> Result<(), EmbeddingRegistryError> {
309 if self
310 .ops
311 .collection_exists(&self.collection)
312 .await
313 .map_err(|e| {
314 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
315 })?
316 {
317 return Ok(());
318 }
319
320 let probe = embed_fn("dimension probe")
321 .await
322 .map_err(|e| EmbeddingRegistryError::DimensionProbe(e.to_string()))?;
323 let vector_size = u64::try_from(probe.len())?;
324
325 self.ops
326 .ensure_collection(&self.collection, vector_size)
327 .await
328 .map_err(|e| {
329 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
330 })?;
331
332 tracing::info!(
333 collection = &self.collection,
334 dimensions = vector_size,
335 "created Qdrant collection"
336 );
337
338 Ok(())
339 }
340
341 async fn recreate_collection(
342 &self,
343 embed_fn: &impl Fn(&str) -> EmbedFuture,
344 ) -> Result<(), EmbeddingRegistryError> {
345 if self
346 .ops
347 .collection_exists(&self.collection)
348 .await
349 .map_err(|e| {
350 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
351 })?
352 {
353 self.ops
354 .delete_collection(&self.collection)
355 .await
356 .map_err(|e| {
357 EmbeddingRegistryError::VectorStore(VectorStoreError::Collection(e.to_string()))
358 })?;
359 tracing::info!(
360 collection = &self.collection,
361 "deleted collection for recreation"
362 );
363 }
364 self.ensure_collection(embed_fn).await
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Embeddable`] used by the tests below.
    struct TestItem {
        k: String,
        text: String,
    }

    impl Embeddable for TestItem {
        fn key(&self) -> &str {
            self.k.as_str()
        }

        fn content_hash(&self) -> String {
            blake3::hash(self.text.as_bytes()).to_hex().to_string()
        }

        fn embed_text(&self) -> &str {
            self.text.as_str()
        }

        fn to_payload(&self) -> serde_json::Value {
            serde_json::json!({"key": self.k, "text": self.text})
        }
    }

    /// Builds a [`TestItem`] from borrowed strings.
    fn make_item(k: &str, text: &str) -> TestItem {
        TestItem {
            k: k.to_owned(),
            text: text.to_owned(),
        }
    }

    /// All-zero namespace shared by the registry tests.
    fn zero_namespace() -> uuid::Uuid {
        uuid::Uuid::from_bytes([0u8; 16])
    }

    #[test]
    fn registry_new_valid_url() {
        let registry = EmbeddingRegistry::new(
            QdrantOps::new("http://localhost:6334").unwrap(),
            "test_col",
            zero_namespace(),
        );
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("EmbeddingRegistry"));
        assert!(rendered.contains("test_col"));
    }

    #[test]
    fn embeddable_content_hash_deterministic() {
        let item = make_item("key", "some text");
        let first = item.content_hash();
        let second = item.content_hash();
        assert_eq!(first, second);
    }

    #[test]
    fn embeddable_content_hash_changes() {
        let hash_a = make_item("key", "text a").content_hash();
        let hash_b = make_item("key", "text b").content_hash();
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn embeddable_payload_contains_key() {
        let payload = make_item("my-key", "desc").to_payload();
        assert_eq!(payload["key"], "my-key");
    }

    #[test]
    fn sync_stats_default() {
        let SyncStats {
            added,
            updated,
            removed,
            unchanged,
        } = SyncStats::default();
        assert_eq!(added, 0);
        assert_eq!(updated, 0);
        assert_eq!(removed, 0);
        assert_eq!(unchanged, 0);
    }

    #[test]
    fn sync_stats_debug() {
        let stats = SyncStats {
            added: 1,
            updated: 2,
            removed: 3,
            unchanged: 4,
        };
        let rendered = format!("{stats:?}");
        assert!(rendered.contains("added"));
    }

    #[tokio::test]
    async fn search_raw_embed_fail_returns_error() {
        let registry = EmbeddingRegistry::new(
            QdrantOps::new("http://localhost:6334").unwrap(),
            "test",
            zero_namespace(),
        );
        // The embedding callback always fails, so the search must surface an error.
        let failing_embed = |_: &str| -> EmbedFuture {
            Box::pin(async {
                Err(Box::new(std::io::Error::other("fail"))
                    as Box<dyn std::error::Error + Send + Sync>)
            })
        };
        let outcome = registry.search_raw("query", 5, failing_embed).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn sync_with_unreachable_qdrant_fails() {
        // Port 1 on loopback: the tests expect no Qdrant to be reachable there.
        let mut registry = EmbeddingRegistry::new(
            QdrantOps::new("http://127.0.0.1:1").unwrap(),
            "test",
            zero_namespace(),
        );
        let items = vec![make_item("k", "text")];
        let fixed_embed =
            |_: &str| -> EmbedFuture { Box::pin(async { Ok(vec![0.1_f32, 0.2]) }) };
        let outcome = registry.sync(&items, "model", fixed_embed).await;
        assert!(outcome.is_err());
    }
}