1use std::collections::HashMap;
11#[allow(unused_imports)]
12use zeph_db::sql;
13
14use zeph_db::{ActiveDialect, DbPool};
15
16use crate::vector_store::{
17 BoxFuture, FieldValue, ScoredVectorPoint, ScrollResult, ScrollWithIdsResult, VectorFilter,
18 VectorPoint, VectorStore, VectorStoreError,
19};
20
21pub struct DbVectorStore {
26 pool: DbPool,
27}
28
29pub type SqliteVectorStore = DbVectorStore;
31
32impl DbVectorStore {
33 #[must_use]
38 pub fn new(pool: DbPool) -> Self {
39 Self { pool }
40 }
41}
42
43use zeph_common::math::cosine_similarity;
44
45fn matches_filter(payload: &HashMap<String, serde_json::Value>, filter: &VectorFilter) -> bool {
46 for cond in &filter.must {
47 let Some(val) = payload.get(&cond.field) else {
48 return false;
49 };
50 let matches = match &cond.value {
51 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
52 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
53 };
54 if !matches {
55 return false;
56 }
57 }
58 for cond in &filter.must_not {
59 let Some(val) = payload.get(&cond.field) else {
60 continue;
61 };
62 let matches = match &cond.value {
63 FieldValue::Integer(i) => val.as_i64().is_some_and(|v| v == *i),
64 FieldValue::Text(t) => val.as_str().is_some_and(|v| v == t.as_str()),
65 };
66 if matches {
67 return false;
68 }
69 }
70 true
71}
72
73impl VectorStore for DbVectorStore {
74 fn ensure_collection(
75 &self,
76 collection: &str,
77 _vector_size: u64,
78 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
79 let collection = collection.to_owned();
80 Box::pin(async move {
81 let sql = format!(
82 "{} INTO vector_collections (name) VALUES (?){}",
83 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
84 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
85 );
86 zeph_db::query(&sql)
87 .bind(&collection)
88 .execute(&self.pool)
89 .await
90 .map_err(|e| VectorStoreError::Collection(e.to_string()))?;
91 Ok(())
92 })
93 }
94
95 fn collection_exists(&self, collection: &str) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
96 let collection = collection.to_owned();
97 Box::pin(async move {
98 let row: (i64,) = zeph_db::query_as(sql!(
99 "SELECT COUNT(*) FROM vector_collections WHERE name = ?"
100 ))
101 .bind(&collection)
102 .fetch_one(&self.pool)
103 .await
104 .map_err(|e| VectorStoreError::Connection(e.to_string()))?;
105 Ok(row.0 > 0)
106 })
107 }
108
109 fn delete_collection(&self, collection: &str) -> BoxFuture<'_, Result<(), VectorStoreError>> {
110 let collection = collection.to_owned();
111 Box::pin(async move {
112 zeph_db::query(sql!("DELETE FROM vector_points WHERE collection = ?"))
113 .bind(&collection)
114 .execute(&self.pool)
115 .await
116 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
117 zeph_db::query(sql!("DELETE FROM vector_collections WHERE name = ?"))
118 .bind(&collection)
119 .execute(&self.pool)
120 .await
121 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
122 Ok(())
123 })
124 }
125
126 fn upsert(
127 &self,
128 collection: &str,
129 points: Vec<VectorPoint>,
130 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
131 let collection = collection.to_owned();
132 #[cfg(feature = "profiling")]
133 let span = tracing::info_span!(
134 "memory.vector_store",
135 operation = "upsert",
136 collection = %collection
137 );
138 let fut = Box::pin(async move {
139 for point in points {
140 let vector_bytes: Vec<u8> =
141 point.vector.iter().flat_map(|f| f.to_le_bytes()).collect();
142 let payload_json = serde_json::to_string(&point.payload)
143 .map_err(|e| VectorStoreError::Serialization(e.to_string()))?;
144 zeph_db::query(
145 sql!("INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?) \
146 ON CONFLICT(collection, id) DO UPDATE SET vector = excluded.vector, payload = excluded.payload"),
147 )
148 .bind(&point.id)
149 .bind(&collection)
150 .bind(&vector_bytes)
151 .bind(&payload_json)
152 .execute(&self.pool)
153 .await
154 .map_err(|e| VectorStoreError::Upsert(e.to_string()))?;
155 }
156 Ok(())
157 });
158 #[cfg(feature = "profiling")]
159 return Box::pin(tracing::Instrument::instrument(fut, span));
160 #[cfg(not(feature = "profiling"))]
161 fut
162 }
163
164 fn search(
165 &self,
166 collection: &str,
167 vector: Vec<f32>,
168 limit: u64,
169 filter: Option<VectorFilter>,
170 ) -> BoxFuture<'_, Result<Vec<ScoredVectorPoint>, VectorStoreError>> {
171 let collection = collection.to_owned();
172 #[cfg(feature = "profiling")]
173 let span = tracing::info_span!(
174 "memory.vector_store",
175 operation = "search",
176 collection = %collection
177 );
178 let fut = Box::pin(async move {
179 let rows: Vec<(String, Vec<u8>, String)> = zeph_db::query_as(sql!(
180 "SELECT id, vector, payload FROM vector_points WHERE collection = ?"
181 ))
182 .bind(&collection)
183 .fetch_all(&self.pool)
184 .await
185 .map_err(|e| VectorStoreError::Search(e.to_string()))?;
186
187 let limit_usize = usize::try_from(limit).unwrap_or(usize::MAX);
188 let mut scored: Vec<ScoredVectorPoint> = rows
189 .into_iter()
190 .filter_map(|(id, blob, payload_str)| {
191 if blob.len() % 4 != 0 {
192 return None;
193 }
194 let stored: Vec<f32> = blob
195 .chunks_exact(4)
196 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
197 .collect();
198 let payload: HashMap<String, serde_json::Value> =
199 serde_json::from_str(&payload_str).unwrap_or_default();
200
201 if filter
202 .as_ref()
203 .is_some_and(|f| !matches_filter(&payload, f))
204 {
205 return None;
206 }
207
208 let score = cosine_similarity(&vector, &stored);
209 Some(ScoredVectorPoint { id, score, payload })
210 })
211 .collect();
212
213 scored.sort_by(|a, b| {
214 b.score
215 .partial_cmp(&a.score)
216 .unwrap_or(std::cmp::Ordering::Equal)
217 });
218 scored.truncate(limit_usize);
219 Ok(scored)
220 });
221 #[cfg(feature = "profiling")]
222 return Box::pin(tracing::Instrument::instrument(fut, span));
223 #[cfg(not(feature = "profiling"))]
224 fut
225 }
226
227 fn delete_by_ids(
228 &self,
229 collection: &str,
230 ids: Vec<String>,
231 ) -> BoxFuture<'_, Result<(), VectorStoreError>> {
232 let collection = collection.to_owned();
233 #[cfg(feature = "profiling")]
234 let span = tracing::info_span!(
235 "memory.vector_store",
236 operation = "delete",
237 collection = %collection
238 );
239 let fut = Box::pin(async move {
240 for id in ids {
241 zeph_db::query(sql!(
242 "DELETE FROM vector_points WHERE collection = ? AND id = ?"
243 ))
244 .bind(&collection)
245 .bind(&id)
246 .execute(&self.pool)
247 .await
248 .map_err(|e| VectorStoreError::Delete(e.to_string()))?;
249 }
250 Ok(())
251 });
252 #[cfg(feature = "profiling")]
253 return Box::pin(tracing::Instrument::instrument(fut, span));
254 #[cfg(not(feature = "profiling"))]
255 fut
256 }
257
258 fn scroll_all(
259 &self,
260 collection: &str,
261 key_field: &str,
262 ) -> BoxFuture<'_, Result<ScrollResult, VectorStoreError>> {
263 let collection = collection.to_owned();
264 let key_field = key_field.to_owned();
265 Box::pin(async move {
266 let rows: Vec<(String, String)> = zeph_db::query_as(sql!(
267 "SELECT id, payload FROM vector_points WHERE collection = ?"
268 ))
269 .bind(&collection)
270 .fetch_all(&self.pool)
271 .await
272 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
273
274 let mut result = ScrollResult::new();
275 for (id, payload_str) in rows {
276 let payload: HashMap<String, serde_json::Value> =
277 serde_json::from_str(&payload_str).unwrap_or_default();
278 if let Some(val) = payload.get(&key_field) {
279 let mut map = HashMap::new();
280 map.insert(
281 key_field.clone(),
282 val.as_str().unwrap_or_default().to_owned(),
283 );
284 result.insert(id, map);
285 }
286 }
287 Ok(result)
288 })
289 }
290
291 fn scroll_all_with_point_ids(
292 &self,
293 collection: &str,
294 key_field: &str,
295 ) -> BoxFuture<'_, Result<ScrollWithIdsResult, VectorStoreError>> {
296 let collection = collection.to_owned();
297 let key_field = key_field.to_owned();
298 Box::pin(async move {
299 let rows: Vec<(String, String)> = zeph_db::query_as(sql!(
300 "SELECT id, payload FROM vector_points WHERE collection = ?"
301 ))
302 .bind(&collection)
303 .fetch_all(&self.pool)
304 .await
305 .map_err(|e| VectorStoreError::Scroll(e.to_string()))?;
306
307 let mut result = Vec::new();
308 for (point_id, payload_str) in rows {
309 let payload: HashMap<String, serde_json::Value> =
310 serde_json::from_str(&payload_str).unwrap_or_default();
311 let Some(key_val) = payload.get(&key_field).and_then(|v| v.as_str()) else {
312 continue;
313 };
314 let mut fields = HashMap::new();
315 for (k, v) in &payload {
316 if let Some(s) = v.as_str() {
317 fields.insert(k.clone(), s.to_owned());
318 }
319 }
320 fields.insert(key_field.clone(), key_val.to_owned());
322 result.push((point_id, fields));
323 }
324 Ok(result)
325 })
326 }
327
328 fn health_check(&self) -> BoxFuture<'_, Result<bool, VectorStoreError>> {
329 Box::pin(async move {
330 zeph_db::query_scalar::<_, i32>(sql!("SELECT 1"))
331 .fetch_one(&self.pool)
332 .await
333 .map(|_| true)
334 .map_err(|e| VectorStoreError::Collection(e.to_string()))
335 })
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;
    use crate::vector_store::FieldCondition;

    /// Payload map carried by every test point.
    type Payload = HashMap<String, serde_json::Value>;

    /// Unit vector along the first axis; also the usual query vector.
    const X_AXIS: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
    /// Unit vector along the second axis.
    const Y_AXIS: [f32; 4] = [0.0, 1.0, 0.0, 0.0];

    /// Opens an in-memory database and wraps a clone of its pool in a vector store.
    async fn setup() -> (DbVectorStore, DbStore) {
        let store = DbStore::new(":memory:").await.unwrap();
        let vs = DbVectorStore::new(store.pool().clone());
        (vs, store)
    }

    /// Creates a fresh store, ensures `collection` exists and upserts `points` into it.
    async fn store_with(collection: &str, points: Vec<VectorPoint>) -> DbVectorStore {
        let (vs, _) = setup().await;
        vs.ensure_collection(collection, 4).await.unwrap();
        vs.upsert(collection, points).await.unwrap();
        vs
    }

    /// Builds a payload from `(key, value)` pairs.
    fn fields<const N: usize>(entries: [(&str, serde_json::Value); N]) -> Payload {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_owned(), value))
            .collect()
    }

    /// Builds a point from its id, vector and payload.
    fn point(id: &str, vector: [f32; 4], payload: Payload) -> VectorPoint {
        VectorPoint {
            id: id.into(),
            vector: vector.to_vec(),
            payload,
        }
    }

    /// Condition "`field` equals the text `value`".
    fn text_is(field: &str, value: &str) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Text(value.into()),
        }
    }

    /// Condition "`field` equals the integer `value`".
    fn int_is(field: &str, value: i64) -> FieldCondition {
        FieldCondition {
            field: field.into(),
            value: FieldValue::Integer(value),
        }
    }

    /// Wraps the two condition lists in a ready-to-pass filter argument.
    fn filter(must: Vec<FieldCondition>, must_not: Vec<FieldCondition>) -> Option<VectorFilter> {
        Some(VectorFilter { must, must_not })
    }

    /// Searches collection `"c"` and unwraps the result.
    async fn search_c(
        vs: &DbVectorStore,
        query: [f32; 4],
        limit: u64,
        filter: Option<VectorFilter>,
    ) -> Vec<ScoredVectorPoint> {
        vs.search("c", query.to_vec(), limit, filter).await.unwrap()
    }

    /// Ids of `hits`, in result order.
    fn ids(hits: &[ScoredVectorPoint]) -> Vec<&str> {
        hits.iter().map(|hit| hit.id.as_str()).collect()
    }

    #[tokio::test]
    async fn ensure_and_exists() {
        let (vs, _) = setup().await;
        assert!(!vs.collection_exists("col1").await.unwrap());
        // Ensuring twice must succeed both times and leave the collection present.
        for _ in 0..2 {
            vs.ensure_collection("col1", 4).await.unwrap();
            assert!(vs.collection_exists("col1").await.unwrap());
        }
    }

    #[tokio::test]
    async fn delete_collection() {
        let vs = store_with("col1", vec![point("p1", X_AXIS, Payload::new())]).await;
        vs.delete_collection("col1").await.unwrap();
        assert!(!vs.collection_exists("col1").await.unwrap());
    }

    #[tokio::test]
    async fn upsert_and_search() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, fields([("role", serde_json::json!("user"))])),
                point(
                    "b",
                    Y_AXIS,
                    fields([("role", serde_json::json!("assistant"))]),
                ),
            ],
        )
        .await;

        let hits = search_c(&vs, X_AXIS, 2, None).await;
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn search_with_filter() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, fields([("role", serde_json::json!("user"))])),
                point(
                    "b",
                    X_AXIS,
                    fields([("role", serde_json::json!("assistant"))]),
                ),
            ],
        )
        .await;

        let only_users = filter(vec![text_is("role", "user")], vec![]);
        let hits = search_c(&vs, X_AXIS, 10, only_users).await;
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn delete_by_ids() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, Payload::new()),
                point("b", Y_AXIS, Payload::new()),
            ],
        )
        .await;
        vs.delete_by_ids("c", vec!["a".into()]).await.unwrap();

        let hits = search_c(&vs, X_AXIS, 10, None).await;
        assert_eq!(ids(&hits), ["b"]);
    }

    #[tokio::test]
    async fn scroll_all() {
        let vs = store_with(
            "c",
            vec![point(
                "p1",
                X_AXIS,
                fields([("text", serde_json::json!("hello"))]),
            )],
        )
        .await;

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result["p1"]["text"], "hello");
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let vs = store_with(
            "c",
            vec![point("p1", X_AXIS, fields([("v", serde_json::json!(1))]))],
        )
        .await;
        // Same id, new vector: the stored point must be replaced, not duplicated.
        vs.upsert(
            "c",
            vec![point("p1", Y_AXIS, fields([("v", serde_json::json!(2))]))],
        )
        .await
        .unwrap();

        let hits = search_c(&vs, Y_AXIS, 1, None).await;
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_import_wired() {
        let similarity = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(!similarity.is_nan());
    }

    #[tokio::test]
    async fn search_with_must_not_filter() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, fields([("role", serde_json::json!("user"))])),
                point("b", X_AXIS, fields([("role", serde_json::json!("system"))])),
            ],
        )
        .await;

        let no_system = filter(vec![], vec![text_is("role", "system")]);
        let hits = search_c(&vs, X_AXIS, 10, no_system).await;
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn search_with_integer_filter() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, fields([("conv_id", serde_json::json!(1))])),
                point("b", X_AXIS, fields([("conv_id", serde_json::json!(2))])),
            ],
        )
        .await;

        let conversation_one = filter(vec![int_is("conv_id", 1)], vec![]);
        let hits = search_c(&vs, X_AXIS, 10, conversation_one).await;
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn search_empty_collection() {
        let (vs, _) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();

        let hits = search_c(&vs, X_AXIS, 10, None).await;
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_with_must_not_integer_filter() {
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, fields([("conv_id", serde_json::json!(1))])),
                point("b", X_AXIS, fields([("conv_id", serde_json::json!(2))])),
            ],
        )
        .await;

        let not_conversation_one = filter(vec![], vec![int_is("conv_id", 1)]);
        let hits = search_c(&vs, X_AXIS, 10, not_conversation_one).await;
        assert_eq!(ids(&hits), ["b"]);
    }

    #[tokio::test]
    async fn search_with_combined_must_and_must_not() {
        let tagged = |role: &str, conv_id: i64| {
            fields([
                ("role", serde_json::json!(role)),
                ("conv_id", serde_json::json!(conv_id)),
            ])
        };
        let vs = store_with(
            "c",
            vec![
                point("a", X_AXIS, tagged("user", 1)),
                point("b", X_AXIS, tagged("user", 2)),
                point("c", X_AXIS, tagged("assistant", 1)),
            ],
        )
        .await;

        let users_outside_conv_two =
            filter(vec![text_is("role", "user")], vec![int_is("conv_id", 2)]);
        let hits = search_c(&vs, X_AXIS, 10, users_outside_conv_two).await;
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn scroll_all_missing_key_field() {
        let vs = store_with(
            "c",
            vec![point(
                "p1",
                X_AXIS,
                fields([("other", serde_json::json!("value"))]),
            )],
        )
        .await;

        let result = vs.scroll_all("c", "text").await.unwrap();
        assert!(
            result.is_empty(),
            "points without the key field must not appear in scroll result"
        );
    }

    #[tokio::test]
    async fn delete_by_ids_empty_and_nonexistent() {
        let vs = store_with("c", vec![point("a", X_AXIS, Payload::new())]).await;

        // Neither an empty id list nor an unknown id may fail or remove anything.
        vs.delete_by_ids("c", vec![]).await.unwrap();
        vs.delete_by_ids("c", vec!["nonexistent".into()])
            .await
            .unwrap();

        let hits = search_c(&vs, X_AXIS, 10, None).await;
        assert_eq!(ids(&hits), ["a"]);
    }

    #[tokio::test]
    async fn search_corrupt_blob_skipped() {
        let (vs, store) = setup().await;
        vs.ensure_collection("c", 4).await.unwrap();
        vs.upsert("c", vec![point("valid", X_AXIS, Payload::new())])
            .await
            .unwrap();

        // Three bytes cannot be split into f32 values, so this row is undecodable.
        let corrupt_blob: Vec<u8> = vec![0xFF, 0xFE, 0xFD];
        let payload_json = r"{}";
        zeph_db::query(sql!(
            "INSERT INTO vector_points (id, collection, vector, payload) VALUES (?, ?, ?, ?)"
        ))
        .bind("corrupt")
        .bind("c")
        .bind(&corrupt_blob)
        .bind(payload_json)
        .execute(store.pool())
        .await
        .unwrap();

        let hits = search_c(&vs, X_AXIS, 10, None).await;
        assert_eq!(ids(&hits), ["valid"]);
    }

    #[tokio::test]
    async fn scroll_all_with_point_ids_basic() {
        let entity = |entity_id: &str, name: &str| {
            fields([
                ("entity_id_str", serde_json::json!(entity_id)),
                ("name", serde_json::json!(name)),
            ])
        };
        let vs = store_with(
            "c",
            vec![
                point("p1", X_AXIS, entity("42", "alice")),
                point("p2", Y_AXIS, entity("99", "bob")),
            ],
        )
        .await;

        let result = vs
            .scroll_all_with_point_ids("c", "entity_id_str")
            .await
            .unwrap();
        assert_eq!(result.len(), 2);

        // Row order is not asserted; index the rows by point id instead.
        let mut by_id: std::collections::BTreeMap<
            String,
            std::collections::HashMap<String, String>,
        > = result.into_iter().collect();
        let p1 = by_id.remove("p1").expect("p1 missing");
        let p2 = by_id.remove("p2").expect("p2 missing");
        assert_eq!(p1.get("entity_id_str").map(String::as_str), Some("42"));
        assert_eq!(p1.get("name").map(String::as_str), Some("alice"));
        assert_eq!(p2.get("entity_id_str").map(String::as_str), Some("99"));
        assert_eq!(p2.get("name").map(String::as_str), Some("bob"));
    }

    #[tokio::test]
    async fn scroll_all_with_point_ids_skips_missing_key_field() {
        let vs = store_with(
            "c",
            vec![
                point(
                    "has-key",
                    X_AXIS,
                    fields([("entity_id_str", serde_json::json!("7"))]),
                ),
                point(
                    "no-key",
                    Y_AXIS,
                    fields([("other", serde_json::json!("value"))]),
                ),
            ],
        )
        .await;

        let result = vs
            .scroll_all_with_point_ids("c", "entity_id_str")
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        let (point_id, point_fields) = &result[0];
        assert_eq!(point_id, "has-key");
        assert_eq!(
            point_fields.get("entity_id_str").map(String::as_str),
            Some("7")
        );
    }
}