1pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Kind of message an embedding was produced from.
///
/// Stored in the vector payload as the boolean `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum MessageKind {
    /// An ordinary conversation message (the default).
    #[default]
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub const fn is_summary(self) -> bool {
        matches!(self, Self::Summary)
    }
}
43
/// Name of the vector collection every `EmbeddingStore` constructor uses for
/// conversation-message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
45
46pub async fn ensure_qdrant_collection(
54 ops: &QdrantOps,
55 collection: &str,
56 vector_size: u64,
57) -> Result<(), Box<qdrant_client::QdrantError>> {
58 ops.ensure_collection(collection, vector_size).await
59}
60
/// Stores message embeddings in a pluggable [`VectorStore`] backend and records the
/// message → vector-point mapping in the `embeddings_metadata` table.
pub struct EmbeddingStore {
    /// Vector backend: Qdrant, the database-backed store, or any other `VectorStore`.
    ops: Box<dyn VectorStore>,
    /// Collection holding conversation-message embeddings (always `COLLECTION_NAME`
    /// when built through the constructors in this file).
    collection: String,
    /// Database pool used for `embeddings_metadata` and related tables.
    pool: DbPool,
}
70
71impl std::fmt::Debug for EmbeddingStore {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("EmbeddingStore")
74 .field("collection", &self.collection)
75 .finish_non_exhaustive()
76 }
77}
78
79#[derive(Debug)]
81pub struct SearchFilter {
82 pub conversation_id: Option<ConversationId>,
84 pub role: Option<String>,
86 pub category: Option<String>,
89}
90
91#[derive(Debug)]
93pub struct SearchResult {
94 pub message_id: MessageId,
96 pub conversation_id: ConversationId,
98 pub score: f32,
100}
101
102impl EmbeddingStore {
103 pub fn new(url: &str, api_key: Option<&str>, pool: DbPool) -> Result<Self, MemoryError> {
113 let ops = QdrantOps::new(url, api_key).map_err(MemoryError::Qdrant)?;
114
115 Ok(Self {
116 ops: Box::new(ops),
117 collection: COLLECTION_NAME.into(),
118 pool,
119 })
120 }
121
122 #[must_use]
126 pub fn new_sqlite(pool: DbPool) -> Self {
127 let ops = DbVectorStore::new(pool.clone());
128 Self {
129 ops: Box::new(ops),
130 collection: COLLECTION_NAME.into(),
131 pool,
132 }
133 }
134
135 #[must_use]
140 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
141 Self {
142 ops: store,
143 collection: COLLECTION_NAME.into(),
144 pool,
145 }
146 }
147
148 pub async fn health_check(&self) -> bool {
150 self.ops.health_check().await.unwrap_or(false)
151 }
152
153 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
161 self.ops
162 .ensure_collection(&self.collection, vector_size)
163 .await?;
164 self.ops
167 .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
168 .await?;
169 Ok(())
170 }
171
172 #[allow(clippy::too_many_arguments)] pub async fn store_with_tool_context(
183 &self,
184 message_id: MessageId,
185 conversation_id: ConversationId,
186 role: &str,
187 vector: Vec<f32>,
188 kind: MessageKind,
189 model: &str,
190 chunk_index: u32,
191 tool_name: &str,
192 exit_code: Option<i32>,
193 timestamp: Option<&str>,
194 ) -> Result<String, MemoryError> {
195 let point_id = uuid::Uuid::new_v4().to_string();
196 let dimensions = i64::try_from(vector.len())?;
197
198 let mut payload = std::collections::HashMap::from([
199 ("message_id".to_owned(), serde_json::json!(message_id.0)),
200 (
201 "conversation_id".to_owned(),
202 serde_json::json!(conversation_id.0),
203 ),
204 ("role".to_owned(), serde_json::json!(role)),
205 (
206 "is_summary".to_owned(),
207 serde_json::json!(kind.is_summary()),
208 ),
209 ("tool_name".to_owned(), serde_json::json!(tool_name)),
210 ]);
211 if let Some(code) = exit_code {
212 payload.insert("exit_code".to_owned(), serde_json::json!(code));
213 }
214 if let Some(ts) = timestamp {
215 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
216 }
217
218 let point = VectorPoint {
219 id: point_id.clone(),
220 vector,
221 payload,
222 };
223
224 self.ops.upsert(&self.collection, vec![point]).await?;
225
226 let chunk_index_i64 = i64::from(chunk_index);
227 zeph_db::query(sql!(
228 "INSERT INTO embeddings_metadata \
229 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
230 VALUES (?, ?, ?, ?, ?) \
231 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
232 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
233 ))
234 .bind(message_id)
235 .bind(chunk_index_i64)
236 .bind(&point_id)
237 .bind(dimensions)
238 .bind(model)
239 .execute(&self.pool)
240 .await?;
241
242 Ok(point_id)
243 }
244
245 #[allow(clippy::too_many_arguments)] pub async fn store(
257 &self,
258 message_id: MessageId,
259 conversation_id: ConversationId,
260 role: &str,
261 vector: Vec<f32>,
262 kind: MessageKind,
263 model: &str,
264 chunk_index: u32,
265 ) -> Result<String, MemoryError> {
266 let point_id = uuid::Uuid::new_v4().to_string();
267 let dimensions = i64::try_from(vector.len())?;
268
269 let payload = std::collections::HashMap::from([
270 ("message_id".to_owned(), serde_json::json!(message_id.0)),
271 (
272 "conversation_id".to_owned(),
273 serde_json::json!(conversation_id.0),
274 ),
275 ("role".to_owned(), serde_json::json!(role)),
276 (
277 "is_summary".to_owned(),
278 serde_json::json!(kind.is_summary()),
279 ),
280 ]);
281
282 let point = VectorPoint {
283 id: point_id.clone(),
284 vector,
285 payload,
286 };
287
288 self.ops.upsert(&self.collection, vec![point]).await?;
289
290 let chunk_index_i64 = i64::from(chunk_index);
291 zeph_db::query(sql!(
292 "INSERT INTO embeddings_metadata \
293 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
294 VALUES (?, ?, ?, ?, ?) \
295 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
296 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
297 ))
298 .bind(message_id)
299 .bind(chunk_index_i64)
300 .bind(&point_id)
301 .bind(dimensions)
302 .bind(model)
303 .execute(&self.pool)
304 .await?;
305
306 Ok(point_id)
307 }
308
309 #[allow(clippy::too_many_arguments)] pub async fn store_with_category(
324 &self,
325 message_id: MessageId,
326 conversation_id: ConversationId,
327 role: &str,
328 vector: Vec<f32>,
329 kind: MessageKind,
330 model: &str,
331 chunk_index: u32,
332 category: Option<&str>,
333 ) -> Result<String, MemoryError> {
334 let point_id = uuid::Uuid::new_v4().to_string();
335 let dimensions = i64::try_from(vector.len())?;
336
337 let mut payload = std::collections::HashMap::from([
338 ("message_id".to_owned(), serde_json::json!(message_id.0)),
339 (
340 "conversation_id".to_owned(),
341 serde_json::json!(conversation_id.0),
342 ),
343 ("role".to_owned(), serde_json::json!(role)),
344 (
345 "is_summary".to_owned(),
346 serde_json::json!(kind.is_summary()),
347 ),
348 ]);
349 if let Some(cat) = category {
350 payload.insert("category".to_owned(), serde_json::json!(cat));
351 }
352
353 let point = VectorPoint {
354 id: point_id.clone(),
355 vector,
356 payload,
357 };
358
359 self.ops.upsert(&self.collection, vec![point]).await?;
360
361 let chunk_index_i64 = i64::from(chunk_index);
362 zeph_db::query(sql!(
363 "INSERT INTO embeddings_metadata \
364 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
365 VALUES (?, ?, ?, ?, ?) \
366 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
367 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
368 ))
369 .bind(message_id)
370 .bind(chunk_index_i64)
371 .bind(&point_id)
372 .bind(dimensions)
373 .bind(model)
374 .execute(&self.pool)
375 .await?;
376
377 Ok(point_id)
378 }
379
380 pub async fn search(
386 &self,
387 query_vector: &[f32],
388 limit: usize,
389 filter: Option<SearchFilter>,
390 ) -> Result<Vec<SearchResult>, MemoryError> {
391 let limit_u64 = u64::try_from(limit)?;
392
393 let vector_filter = filter.as_ref().and_then(|f| {
394 let mut must = Vec::new();
395 if let Some(cid) = f.conversation_id {
396 must.push(FieldCondition {
397 field: "conversation_id".into(),
398 value: FieldValue::Integer(cid.0),
399 });
400 }
401 if let Some(ref role) = f.role {
402 must.push(FieldCondition {
403 field: "role".into(),
404 value: FieldValue::Text(role.clone()),
405 });
406 }
407 if let Some(ref category) = f.category {
408 must.push(FieldCondition {
409 field: "category".into(),
410 value: FieldValue::Text(category.clone()),
411 });
412 }
413 if must.is_empty() {
414 None
415 } else {
416 Some(VectorFilter {
417 must,
418 must_not: vec![],
419 })
420 }
421 });
422
423 let results = self
424 .ops
425 .search(
426 &self.collection,
427 query_vector.to_vec(),
428 limit_u64,
429 vector_filter,
430 )
431 .await?;
432
433 let mut best: std::collections::HashMap<MessageId, SearchResult> =
436 std::collections::HashMap::new();
437 for point in results {
438 let Some(message_id) = point
439 .payload
440 .get("message_id")
441 .and_then(serde_json::Value::as_i64)
442 else {
443 continue;
444 };
445 let Some(conversation_id) = point
446 .payload
447 .get("conversation_id")
448 .and_then(serde_json::Value::as_i64)
449 else {
450 continue;
451 };
452 let message_id = MessageId(message_id);
453 let entry = best.entry(message_id).or_insert(SearchResult {
454 message_id,
455 conversation_id: ConversationId(conversation_id),
456 score: f32::NEG_INFINITY,
457 });
458 if point.score > entry.score {
459 entry.score = point.score;
460 }
461 }
462
463 let mut search_results: Vec<SearchResult> = best.into_values().collect();
464 search_results.sort_by(|a, b| {
465 b.score
466 .partial_cmp(&a.score)
467 .unwrap_or(std::cmp::Ordering::Equal)
468 });
469 search_results.truncate(limit);
470
471 Ok(search_results)
472 }
473
474 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
480 self.ops.collection_exists(name).await.map_err(Into::into)
481 }
482
483 pub async fn ensure_named_collection(
489 &self,
490 name: &str,
491 vector_size: u64,
492 ) -> Result<(), MemoryError> {
493 self.ops.ensure_collection(name, vector_size).await?;
494 Ok(())
495 }
496
497 pub async fn store_to_collection(
505 &self,
506 collection: &str,
507 payload: serde_json::Value,
508 vector: Vec<f32>,
509 ) -> Result<String, MemoryError> {
510 let point_id = uuid::Uuid::new_v4().to_string();
511 let payload_map: std::collections::HashMap<String, serde_json::Value> =
512 serde_json::from_value(payload)?;
513 let point = VectorPoint {
514 id: point_id.clone(),
515 vector,
516 payload: payload_map,
517 };
518 self.ops.upsert(collection, vec![point]).await?;
519 Ok(point_id)
520 }
521
522 pub async fn upsert_to_collection(
530 &self,
531 collection: &str,
532 point_id: &str,
533 payload: serde_json::Value,
534 vector: Vec<f32>,
535 ) -> Result<(), MemoryError> {
536 let payload_map: std::collections::HashMap<String, serde_json::Value> =
537 serde_json::from_value(payload)?;
538 let point = VectorPoint {
539 id: point_id.to_owned(),
540 vector,
541 payload: payload_map,
542 };
543 self.ops.upsert(collection, vec![point]).await?;
544 Ok(())
545 }
546
547 pub async fn search_collection(
553 &self,
554 collection: &str,
555 query_vector: &[f32],
556 limit: usize,
557 filter: Option<VectorFilter>,
558 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
559 let limit_u64 = u64::try_from(limit)?;
560 let results = self
561 .ops
562 .search(collection, query_vector.to_vec(), limit_u64, filter)
563 .await?;
564 Ok(results)
565 }
566
567 pub async fn scroll_all_entity_ids(
580 &self,
581 collection: &str,
582 ) -> Result<Vec<(String, i64)>, MemoryError> {
583 let rows = self
584 .ops
585 .scroll_all_with_point_ids(collection, "entity_id_str")
586 .await?;
587 let mut out = Vec::with_capacity(rows.len());
588 for (point_id, fields) in rows {
589 let Some(s) = fields.get("entity_id_str") else {
590 continue;
591 };
592 if let Ok(id) = s.parse::<i64>() {
593 out.push((point_id, id));
594 } else {
595 tracing::debug!(point_id, value = %s, "entity_id_str unparseable, skipping");
596 }
597 }
598 Ok(out)
599 }
600
601 pub async fn delete_from_collection(
610 &self,
611 collection: &str,
612 ids: Vec<String>,
613 ) -> Result<(), MemoryError> {
614 if ids.is_empty() {
615 return Ok(());
616 }
617 self.ops.delete_by_ids(collection, ids).await?;
618 Ok(())
619 }
620
621 pub async fn get_vectors_from_collection(
631 &self,
632 collection: &str,
633 point_ids: &[String],
634 ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
635 if point_ids.is_empty() {
636 return Ok(std::collections::HashMap::new());
637 }
638 match self.ops.get_points(collection, point_ids.to_vec()).await {
639 Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
640 Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
641 Err(e) => Err(MemoryError::VectorStore(e)),
642 }
643 }
644
645 pub async fn get_vectors(
653 &self,
654 ids: &[MessageId],
655 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
656 if ids.is_empty() {
657 return Ok(std::collections::HashMap::new());
658 }
659
660 let placeholders = zeph_db::placeholder_list(1, ids.len());
661 let query = format!(
662 "SELECT em.message_id, vp.vector \
663 FROM embeddings_metadata em \
664 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
665 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
666 );
667 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
668 for &id in ids {
669 q = q.bind(id);
670 }
671
672 let rows = q.fetch_all(&self.pool).await?;
673
674 let map = rows
675 .into_iter()
676 .filter_map(|(msg_id, blob)| {
677 if blob.len() % 4 != 0 {
678 return None;
679 }
680 let vec: Vec<f32> = blob
681 .chunks_exact(4)
682 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
683 .collect();
684 Some((msg_id, vec))
685 })
686 .collect();
687
688 Ok(map)
689 }
690
691 pub async fn get_vectors_for_messages(
706 &self,
707 ids: &[MessageId],
708 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
709 if ids.is_empty() {
710 return Ok(std::collections::HashMap::new());
711 }
712
713 let placeholders = zeph_db::placeholder_list(1, ids.len());
714 let query = format!(
715 "SELECT message_id, qdrant_point_id \
716 FROM embeddings_metadata \
717 WHERE message_id IN ({placeholders}) AND chunk_index = 0"
718 );
719 let mut q = zeph_db::query_as::<_, (MessageId, String)>(&query);
720 for &id in ids {
721 q = q.bind(id);
722 }
723 let rows: Vec<(MessageId, String)> = q.fetch_all(&self.pool).await?;
724
725 if rows.is_empty() {
726 return Ok(std::collections::HashMap::new());
727 }
728
729 let mut point_to_msg: std::collections::HashMap<String, MessageId> =
731 std::collections::HashMap::with_capacity(rows.len());
732 let point_ids: Vec<String> = rows
733 .into_iter()
734 .map(|(msg_id, point_id)| {
735 point_to_msg.insert(point_id.clone(), msg_id);
736 point_id
737 })
738 .collect();
739
740 let points = match self.ops.get_points(&self.collection, point_ids).await {
741 Ok(pts) => pts,
742 Err(crate::VectorStoreError::Unsupported(_)) => {
743 return Ok(std::collections::HashMap::new());
744 }
745 Err(e) => return Err(MemoryError::VectorStore(e)),
746 };
747
748 let result = points
749 .into_iter()
750 .filter_map(|p| {
751 let msg_id = point_to_msg.get(&p.id).copied()?;
752 Some((msg_id, p.vector))
753 })
754 .collect();
755
756 Ok(result)
757 }
758
759 pub async fn delete_by_message_ids(&self, ids: &[MessageId]) -> Result<usize, MemoryError> {
773 if ids.is_empty() {
774 return Ok(0);
775 }
776
777 let placeholders = zeph_db::placeholder_list(1, ids.len());
778 let query = format!(
779 "SELECT qdrant_point_id FROM embeddings_metadata WHERE message_id IN ({placeholders})"
780 );
781 let mut q = zeph_db::query_as::<_, (String,)>(&query);
782 for &id in ids {
783 q = q.bind(id);
784 }
785 let rows: Vec<(String,)> = q.fetch_all(&self.pool).await?;
786
787 let point_ids: Vec<String> = rows.into_iter().map(|(id,)| id).collect();
788 let count = point_ids.len();
789
790 if !point_ids.is_empty() {
791 self.ops.delete_by_ids(&self.collection, point_ids).await?;
792 }
793
794 Ok(count)
795 }
796
797 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
803 let row: (i64,) = zeph_db::query_as(sql!(
804 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
805 ))
806 .bind(message_id)
807 .fetch_one(&self.pool)
808 .await?;
809
810 Ok(row.0 > 0)
811 }
812
813 pub async fn is_epoch_current(
823 &self,
824 entity_name: &str,
825 qdrant_epoch: u64,
826 ) -> Result<bool, MemoryError> {
827 let row: Option<(i64,)> = zeph_db::query_as(sql!(
828 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
829 ))
830 .bind(entity_name)
831 .fetch_optional(&self.pool)
832 .await?;
833
834 match row {
835 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
837 }
838 }
839}
840
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// Opens a fresh in-memory `SqliteStore` and returns it with a clone of its pool.
    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// Builds an `EmbeddingStore` over an `InMemoryVectorStore` with a 4-dimensional
    /// collection. The `SqliteStore` is returned so tests can create conversations
    /// and messages.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Stores a `Regular`, `user`-role embedding, which is the shape most tests need.
    async fn store_user_vector(
        store: &EmbeddingStore,
        message_id: MessageId,
        conversation_id: ConversationId,
        vector: Vec<f32>,
        model: &str,
        chunk_index: u32,
    ) {
        store
            .store(
                message_id,
                conversation_id,
                "user",
                vector,
                MessageKind::Regular,
                model,
                chunk_index,
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store_user_vector(&store, msg_id, cid, vec![1.0, 0.0, 0.0, 0.0], "test-model", 0).await;

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store_user_vector(&store, msg_id, cid, vec![0.0, 1.0, 0.0, 0.0], "test-model", 0).await;

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store_user_vector(&store, msg1, cid1, vec![1.0, 0.0, 0.0, 0.0], "m", 0).await;
        store_user_vector(&store, msg2, cid2, vec![1.0, 0.0, 0.0, 0.0], "m", 0).await;

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                    category: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn embedding_store_search_dedups_chunks_keeping_best_score() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg = sqlite.save_message(cid, "user", "two chunks").await.unwrap();

        // Chunk 0 is orthogonal to the query; chunk 1 matches it exactly.
        store_user_vector(&store, msg, cid, vec![0.0, 1.0, 0.0, 0.0], "m", 0).await;
        store_user_vector(&store, msg, cid, vec![1.0, 0.0, 0.0, 0.0], "m", 1).await;

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert_eq!(
            results.len(),
            1,
            "chunks of one message must collapse into a single result"
        );
        assert_eq!(results[0].message_id, msg);
        assert!(
            (results[0].score - 1.0).abs() < 0.001,
            "the best chunk score must be kept"
        );
    }

    #[tokio::test]
    async fn embedding_store_store_with_tool_context_is_searchable() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "tool output").await.unwrap();

        store
            .store_with_tool_context(
                msg_id,
                cid,
                "user",
                vec![0.0, 0.0, 0.0, 1.0],
                MessageKind::Regular,
                "m",
                0,
                "shell",
                Some(0),
                Some("2024-01-01T00:00:00Z"),
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
        let results = store.search(&[0.0, 0.0, 0.0, 1.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // Same (message_id, chunk_index, model) must be rejected.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;
        assert!(result.is_err());

        // A different chunk_index for the same message and model is allowed.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(1_i64)
        .bind(&point_id3)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn get_vectors_for_messages_returns_correct_vectors() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid, "user", "hello").await.unwrap();
        let msg2 = sqlite.save_message(cid, "user", "world").await.unwrap();

        store_user_vector(&store, msg1, cid, vec![1.0, 0.0, 0.0, 0.0], "m", 0).await;
        store_user_vector(&store, msg2, cid, vec![0.0, 1.0, 0.0, 0.0], "m", 0).await;

        let result = store.get_vectors_for_messages(&[msg1, msg2]).await.unwrap();
        assert_eq!(result.len(), 2);
        let v1 = result.get(&msg1).unwrap();
        let v2 = result.get(&msg2).unwrap();
        assert!((v1[0] - 1.0).abs() < f32::EPSILON);
        assert!((v2[1] - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn get_vectors_for_messages_missing_id_is_dropped() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid, "user", "present").await.unwrap();
        let msg_absent = MessageId(99_999);

        store_user_vector(&store, msg1, cid, vec![1.0, 0.0, 0.0, 0.0], "m", 0).await;

        let result = store
            .get_vectors_for_messages(&[msg1, msg_absent])
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert!(result.contains_key(&msg1));
        assert!(!result.contains_key(&msg_absent));
    }

    #[tokio::test]
    async fn get_vectors_for_messages_empty_input() {
        let (store, _sqlite) = setup_with_store().await;
        let result = store.get_vectors_for_messages(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn get_vectors_for_messages_chunk_index_0_only() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg = sqlite.save_message(cid, "user", "chunked").await.unwrap();

        store_user_vector(&store, msg, cid, vec![1.0, 0.0, 0.0, 0.0], "m", 0).await;
        store_user_vector(&store, msg, cid, vec![0.0, 0.0, 1.0, 0.0], "m", 1).await;

        let result = store.get_vectors_for_messages(&[msg]).await.unwrap();
        assert_eq!(result.len(), 1);
        let v = result.get(&msg).unwrap();
        assert!(
            (v[0] - 1.0).abs() < f32::EPSILON,
            "expected chunk_index=0 vector"
        );
    }

    #[tokio::test]
    async fn embedding_store_delete_by_message_ids_resolves_via_metadata() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        store_user_vector(&store, msg_id, cid, vec![1.0, 0.0, 0.0, 0.0], "test-model", 0).await;

        assert!(store.has_embedding(msg_id).await.unwrap());

        let deleted = store.delete_by_message_ids(&[msg_id]).await.unwrap();
        assert_eq!(deleted, 1, "one point id should have been targeted");

        // Only the vector points are removed; the metadata rows must remain.
        let pool = sqlite.pool().clone();
        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(
            row.0, 1,
            "embeddings_metadata row must survive delete_by_message_ids"
        );
    }

    #[tokio::test]
    async fn embedding_store_delete_by_message_ids_empty_slice_is_noop() {
        let (store, _sqlite) = setup_with_store().await;
        let deleted = store.delete_by_message_ids(&[]).await.unwrap();
        assert_eq!(deleted, 0);
    }
}