1pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Kind of message an embedding belongs to.
///
/// The `store*` methods of `EmbeddingStore` record this in the point payload
/// as the boolean `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary message (`is_summary` is `false`).
    Regular,
    /// A summary message (`is_summary` is `true`).
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for every other kind.
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
42
/// Collection name assigned to `EmbeddingStore::collection` by every constructor in this file.
const COLLECTION_NAME: &str = "zeph_conversations";
44
45pub async fn ensure_qdrant_collection(
53 ops: &QdrantOps,
54 collection: &str,
55 vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57 ops.ensure_collection(collection, vector_size).await
58}
59
60pub struct EmbeddingStore {
65 ops: Box<dyn VectorStore>,
66 collection: String,
67 pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("EmbeddingStore")
73 .field("collection", &self.collection)
74 .finish_non_exhaustive()
75 }
76}
77
78#[derive(Debug)]
80pub struct SearchFilter {
81 pub conversation_id: Option<ConversationId>,
83 pub role: Option<String>,
85 pub category: Option<String>,
88}
89
90#[derive(Debug)]
92pub struct SearchResult {
93 pub message_id: MessageId,
95 pub conversation_id: ConversationId,
97 pub score: f32,
99}
100
101impl EmbeddingStore {
102 pub fn new(url: &str, api_key: Option<&str>, pool: DbPool) -> Result<Self, MemoryError> {
112 let ops = QdrantOps::new(url, api_key).map_err(MemoryError::Qdrant)?;
113
114 Ok(Self {
115 ops: Box::new(ops),
116 collection: COLLECTION_NAME.into(),
117 pool,
118 })
119 }
120
121 #[must_use]
125 pub fn new_sqlite(pool: DbPool) -> Self {
126 let ops = DbVectorStore::new(pool.clone());
127 Self {
128 ops: Box::new(ops),
129 collection: COLLECTION_NAME.into(),
130 pool,
131 }
132 }
133
134 #[must_use]
135 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
136 Self {
137 ops: store,
138 collection: COLLECTION_NAME.into(),
139 pool,
140 }
141 }
142
143 pub async fn health_check(&self) -> bool {
144 self.ops.health_check().await.unwrap_or(false)
145 }
146
147 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
155 self.ops
156 .ensure_collection(&self.collection, vector_size)
157 .await?;
158 self.ops
161 .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
162 .await?;
163 Ok(())
164 }
165
166 #[allow(clippy::too_many_arguments)] pub async fn store_with_tool_context(
177 &self,
178 message_id: MessageId,
179 conversation_id: ConversationId,
180 role: &str,
181 vector: Vec<f32>,
182 kind: MessageKind,
183 model: &str,
184 chunk_index: u32,
185 tool_name: &str,
186 exit_code: Option<i32>,
187 timestamp: Option<&str>,
188 ) -> Result<String, MemoryError> {
189 let point_id = uuid::Uuid::new_v4().to_string();
190 let dimensions = i64::try_from(vector.len())?;
191
192 let mut payload = std::collections::HashMap::from([
193 ("message_id".to_owned(), serde_json::json!(message_id.0)),
194 (
195 "conversation_id".to_owned(),
196 serde_json::json!(conversation_id.0),
197 ),
198 ("role".to_owned(), serde_json::json!(role)),
199 (
200 "is_summary".to_owned(),
201 serde_json::json!(kind.is_summary()),
202 ),
203 ("tool_name".to_owned(), serde_json::json!(tool_name)),
204 ]);
205 if let Some(code) = exit_code {
206 payload.insert("exit_code".to_owned(), serde_json::json!(code));
207 }
208 if let Some(ts) = timestamp {
209 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
210 }
211
212 let point = VectorPoint {
213 id: point_id.clone(),
214 vector,
215 payload,
216 };
217
218 self.ops.upsert(&self.collection, vec![point]).await?;
219
220 let chunk_index_i64 = i64::from(chunk_index);
221 zeph_db::query(sql!(
222 "INSERT INTO embeddings_metadata \
223 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
224 VALUES (?, ?, ?, ?, ?) \
225 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
226 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
227 ))
228 .bind(message_id)
229 .bind(chunk_index_i64)
230 .bind(&point_id)
231 .bind(dimensions)
232 .bind(model)
233 .execute(&self.pool)
234 .await?;
235
236 Ok(point_id)
237 }
238
239 #[allow(clippy::too_many_arguments)] pub async fn store(
251 &self,
252 message_id: MessageId,
253 conversation_id: ConversationId,
254 role: &str,
255 vector: Vec<f32>,
256 kind: MessageKind,
257 model: &str,
258 chunk_index: u32,
259 ) -> Result<String, MemoryError> {
260 let point_id = uuid::Uuid::new_v4().to_string();
261 let dimensions = i64::try_from(vector.len())?;
262
263 let payload = std::collections::HashMap::from([
264 ("message_id".to_owned(), serde_json::json!(message_id.0)),
265 (
266 "conversation_id".to_owned(),
267 serde_json::json!(conversation_id.0),
268 ),
269 ("role".to_owned(), serde_json::json!(role)),
270 (
271 "is_summary".to_owned(),
272 serde_json::json!(kind.is_summary()),
273 ),
274 ]);
275
276 let point = VectorPoint {
277 id: point_id.clone(),
278 vector,
279 payload,
280 };
281
282 self.ops.upsert(&self.collection, vec![point]).await?;
283
284 let chunk_index_i64 = i64::from(chunk_index);
285 zeph_db::query(sql!(
286 "INSERT INTO embeddings_metadata \
287 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
288 VALUES (?, ?, ?, ?, ?) \
289 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
290 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
291 ))
292 .bind(message_id)
293 .bind(chunk_index_i64)
294 .bind(&point_id)
295 .bind(dimensions)
296 .bind(model)
297 .execute(&self.pool)
298 .await?;
299
300 Ok(point_id)
301 }
302
303 #[allow(clippy::too_many_arguments)] pub async fn store_with_category(
318 &self,
319 message_id: MessageId,
320 conversation_id: ConversationId,
321 role: &str,
322 vector: Vec<f32>,
323 kind: MessageKind,
324 model: &str,
325 chunk_index: u32,
326 category: Option<&str>,
327 ) -> Result<String, MemoryError> {
328 let point_id = uuid::Uuid::new_v4().to_string();
329 let dimensions = i64::try_from(vector.len())?;
330
331 let mut payload = std::collections::HashMap::from([
332 ("message_id".to_owned(), serde_json::json!(message_id.0)),
333 (
334 "conversation_id".to_owned(),
335 serde_json::json!(conversation_id.0),
336 ),
337 ("role".to_owned(), serde_json::json!(role)),
338 (
339 "is_summary".to_owned(),
340 serde_json::json!(kind.is_summary()),
341 ),
342 ]);
343 if let Some(cat) = category {
344 payload.insert("category".to_owned(), serde_json::json!(cat));
345 }
346
347 let point = VectorPoint {
348 id: point_id.clone(),
349 vector,
350 payload,
351 };
352
353 self.ops.upsert(&self.collection, vec![point]).await?;
354
355 let chunk_index_i64 = i64::from(chunk_index);
356 zeph_db::query(sql!(
357 "INSERT INTO embeddings_metadata \
358 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
359 VALUES (?, ?, ?, ?, ?) \
360 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
361 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
362 ))
363 .bind(message_id)
364 .bind(chunk_index_i64)
365 .bind(&point_id)
366 .bind(dimensions)
367 .bind(model)
368 .execute(&self.pool)
369 .await?;
370
371 Ok(point_id)
372 }
373
374 pub async fn search(
380 &self,
381 query_vector: &[f32],
382 limit: usize,
383 filter: Option<SearchFilter>,
384 ) -> Result<Vec<SearchResult>, MemoryError> {
385 let limit_u64 = u64::try_from(limit)?;
386
387 let vector_filter = filter.as_ref().and_then(|f| {
388 let mut must = Vec::new();
389 if let Some(cid) = f.conversation_id {
390 must.push(FieldCondition {
391 field: "conversation_id".into(),
392 value: FieldValue::Integer(cid.0),
393 });
394 }
395 if let Some(ref role) = f.role {
396 must.push(FieldCondition {
397 field: "role".into(),
398 value: FieldValue::Text(role.clone()),
399 });
400 }
401 if let Some(ref category) = f.category {
402 must.push(FieldCondition {
403 field: "category".into(),
404 value: FieldValue::Text(category.clone()),
405 });
406 }
407 if must.is_empty() {
408 None
409 } else {
410 Some(VectorFilter {
411 must,
412 must_not: vec![],
413 })
414 }
415 });
416
417 let results = self
418 .ops
419 .search(
420 &self.collection,
421 query_vector.to_vec(),
422 limit_u64,
423 vector_filter,
424 )
425 .await?;
426
427 let mut best: std::collections::HashMap<MessageId, SearchResult> =
430 std::collections::HashMap::new();
431 for point in results {
432 let Some(message_id) = point
433 .payload
434 .get("message_id")
435 .and_then(serde_json::Value::as_i64)
436 else {
437 continue;
438 };
439 let Some(conversation_id) = point
440 .payload
441 .get("conversation_id")
442 .and_then(serde_json::Value::as_i64)
443 else {
444 continue;
445 };
446 let message_id = MessageId(message_id);
447 let entry = best.entry(message_id).or_insert(SearchResult {
448 message_id,
449 conversation_id: ConversationId(conversation_id),
450 score: f32::NEG_INFINITY,
451 });
452 if point.score > entry.score {
453 entry.score = point.score;
454 }
455 }
456
457 let mut search_results: Vec<SearchResult> = best.into_values().collect();
458 search_results.sort_by(|a, b| {
459 b.score
460 .partial_cmp(&a.score)
461 .unwrap_or(std::cmp::Ordering::Equal)
462 });
463 search_results.truncate(limit);
464
465 Ok(search_results)
466 }
467
468 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
474 self.ops.collection_exists(name).await.map_err(Into::into)
475 }
476
477 pub async fn ensure_named_collection(
483 &self,
484 name: &str,
485 vector_size: u64,
486 ) -> Result<(), MemoryError> {
487 self.ops.ensure_collection(name, vector_size).await?;
488 Ok(())
489 }
490
491 pub async fn store_to_collection(
499 &self,
500 collection: &str,
501 payload: serde_json::Value,
502 vector: Vec<f32>,
503 ) -> Result<String, MemoryError> {
504 let point_id = uuid::Uuid::new_v4().to_string();
505 let payload_map: std::collections::HashMap<String, serde_json::Value> =
506 serde_json::from_value(payload)?;
507 let point = VectorPoint {
508 id: point_id.clone(),
509 vector,
510 payload: payload_map,
511 };
512 self.ops.upsert(collection, vec![point]).await?;
513 Ok(point_id)
514 }
515
516 pub async fn upsert_to_collection(
524 &self,
525 collection: &str,
526 point_id: &str,
527 payload: serde_json::Value,
528 vector: Vec<f32>,
529 ) -> Result<(), MemoryError> {
530 let payload_map: std::collections::HashMap<String, serde_json::Value> =
531 serde_json::from_value(payload)?;
532 let point = VectorPoint {
533 id: point_id.to_owned(),
534 vector,
535 payload: payload_map,
536 };
537 self.ops.upsert(collection, vec![point]).await?;
538 Ok(())
539 }
540
541 pub async fn search_collection(
547 &self,
548 collection: &str,
549 query_vector: &[f32],
550 limit: usize,
551 filter: Option<VectorFilter>,
552 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
553 let limit_u64 = u64::try_from(limit)?;
554 let results = self
555 .ops
556 .search(collection, query_vector.to_vec(), limit_u64, filter)
557 .await?;
558 Ok(results)
559 }
560
561 pub async fn scroll_all_entity_ids(
574 &self,
575 collection: &str,
576 ) -> Result<Vec<(String, i64)>, MemoryError> {
577 let rows = self
578 .ops
579 .scroll_all_with_point_ids(collection, "entity_id_str")
580 .await?;
581 let mut out = Vec::with_capacity(rows.len());
582 for (point_id, fields) in rows {
583 let Some(s) = fields.get("entity_id_str") else {
584 continue;
585 };
586 if let Ok(id) = s.parse::<i64>() {
587 out.push((point_id, id));
588 } else {
589 tracing::debug!(point_id, value = %s, "entity_id_str unparseable, skipping");
590 }
591 }
592 Ok(out)
593 }
594
595 pub async fn delete_from_collection(
604 &self,
605 collection: &str,
606 ids: Vec<String>,
607 ) -> Result<(), MemoryError> {
608 if ids.is_empty() {
609 return Ok(());
610 }
611 self.ops.delete_by_ids(collection, ids).await?;
612 Ok(())
613 }
614
615 pub async fn get_vectors_from_collection(
625 &self,
626 collection: &str,
627 point_ids: &[String],
628 ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
629 if point_ids.is_empty() {
630 return Ok(std::collections::HashMap::new());
631 }
632 match self.ops.get_points(collection, point_ids.to_vec()).await {
633 Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
634 Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
635 Err(e) => Err(MemoryError::VectorStore(e)),
636 }
637 }
638
639 pub async fn get_vectors(
647 &self,
648 ids: &[MessageId],
649 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
650 if ids.is_empty() {
651 return Ok(std::collections::HashMap::new());
652 }
653
654 let placeholders = zeph_db::placeholder_list(1, ids.len());
655 let query = format!(
656 "SELECT em.message_id, vp.vector \
657 FROM embeddings_metadata em \
658 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
659 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
660 );
661 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
662 for &id in ids {
663 q = q.bind(id);
664 }
665
666 let rows = q.fetch_all(&self.pool).await?;
667
668 let map = rows
669 .into_iter()
670 .filter_map(|(msg_id, blob)| {
671 if blob.len() % 4 != 0 {
672 return None;
673 }
674 let vec: Vec<f32> = blob
675 .chunks_exact(4)
676 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
677 .collect();
678 Some((msg_id, vec))
679 })
680 .collect();
681
682 Ok(map)
683 }
684
685 pub async fn get_vectors_for_messages(
700 &self,
701 ids: &[MessageId],
702 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
703 if ids.is_empty() {
704 return Ok(std::collections::HashMap::new());
705 }
706
707 let placeholders = zeph_db::placeholder_list(1, ids.len());
708 let query = format!(
709 "SELECT message_id, qdrant_point_id \
710 FROM embeddings_metadata \
711 WHERE message_id IN ({placeholders}) AND chunk_index = 0"
712 );
713 let mut q = zeph_db::query_as::<_, (MessageId, String)>(&query);
714 for &id in ids {
715 q = q.bind(id);
716 }
717 let rows: Vec<(MessageId, String)> = q.fetch_all(&self.pool).await?;
718
719 if rows.is_empty() {
720 return Ok(std::collections::HashMap::new());
721 }
722
723 let mut point_to_msg: std::collections::HashMap<String, MessageId> =
725 std::collections::HashMap::with_capacity(rows.len());
726 let point_ids: Vec<String> = rows
727 .into_iter()
728 .map(|(msg_id, point_id)| {
729 point_to_msg.insert(point_id.clone(), msg_id);
730 point_id
731 })
732 .collect();
733
734 let points = match self.ops.get_points(&self.collection, point_ids).await {
735 Ok(pts) => pts,
736 Err(crate::VectorStoreError::Unsupported(_)) => {
737 return Ok(std::collections::HashMap::new());
738 }
739 Err(e) => return Err(MemoryError::VectorStore(e)),
740 };
741
742 let result = points
743 .into_iter()
744 .filter_map(|p| {
745 let msg_id = point_to_msg.get(&p.id).copied()?;
746 Some((msg_id, p.vector))
747 })
748 .collect();
749
750 Ok(result)
751 }
752
753 pub async fn delete_by_message_ids(&self, ids: &[MessageId]) -> Result<usize, MemoryError> {
767 if ids.is_empty() {
768 return Ok(0);
769 }
770
771 let placeholders = zeph_db::placeholder_list(1, ids.len());
772 let query = format!(
773 "SELECT qdrant_point_id FROM embeddings_metadata WHERE message_id IN ({placeholders})"
774 );
775 let mut q = zeph_db::query_as::<_, (String,)>(&query);
776 for &id in ids {
777 q = q.bind(id);
778 }
779 let rows: Vec<(String,)> = q.fetch_all(&self.pool).await?;
780
781 let point_ids: Vec<String> = rows.into_iter().map(|(id,)| id).collect();
782 let count = point_ids.len();
783
784 if !point_ids.is_empty() {
785 self.ops.delete_by_ids(&self.collection, point_ids).await?;
786 }
787
788 Ok(count)
789 }
790
791 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
797 let row: (i64,) = zeph_db::query_as(sql!(
798 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
799 ))
800 .bind(message_id)
801 .fetch_one(&self.pool)
802 .await?;
803
804 Ok(row.0 > 0)
805 }
806
807 pub async fn is_epoch_current(
817 &self,
818 entity_name: &str,
819 qdrant_epoch: u64,
820 ) -> Result<bool, MemoryError> {
821 let row: Option<(i64,)> = zeph_db::query_as(sql!(
822 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
823 ))
824 .bind(entity_name)
825 .fetch_optional(&self.pool)
826 .await?;
827
828 match row {
829 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
831 }
832 }
833}
834
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// Expands to an awaited plain `INSERT` into `embeddings_metadata` with
    /// `dimensions = 768` and `model = "qwen3-embedding"`.
    ///
    /// The caller decides what to do with the result (`unwrap`, `is_err`, ...).
    macro_rules! insert_metadata_row {
        ($pool:expr, $message_id:expr, $chunk_index:expr, $point_id:expr) => {
            zeph_db::query(sql!(
                "INSERT INTO embeddings_metadata \
                 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
                 VALUES (?, ?, ?, ?, ?)"
            ))
            .bind($message_id)
            .bind($chunk_index)
            .bind($point_id)
            .bind(768_i64)
            .bind("qwen3-embedding")
            .execute($pool)
            .await
        };
    }

    /// Opens an in-memory database and hands back the store plus a pool clone.
    async fn setup() -> (SqliteStore, DbPool) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        (sqlite, pool)
    }

    /// Builds an `EmbeddingStore` over an in-memory vector store with a
    /// 4-dimensional collection already created.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let (sqlite, pool) = setup().await;
        let backend = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(backend, pool);
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Stores a regular "user" embedding and panics if that fails.
    async fn embed(
        store: &EmbeddingStore,
        message_id: MessageId,
        conversation_id: ConversationId,
        vector: [f32; 4],
        model: &str,
        chunk_index: u32,
    ) {
        store
            .store(
                message_id,
                conversation_id,
                "user",
                vector.to_vec(),
                MessageKind::Regular,
                model,
                chunk_index,
            )
            .await
            .unwrap();
    }

    /// Counts the `embeddings_metadata` rows recorded for `message_id`.
    async fn metadata_row_count(pool: &DbPool, message_id: MessageId) -> i64 {
        let (count,): (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(message_id)
        .fetch_one(pool)
        .await
        .unwrap();
        count
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // `_store` is held so the in-memory database outlives the query.
        let (_store, pool) = setup().await;

        let (count,): (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(999_i64)
        .fetch_one(&pool)
        .await
        .unwrap();

        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        insert_metadata_row!(&pool, msg_id, 0_i64, &point_id).unwrap();

        assert_eq!(metadata_row_count(&pool, msg_id).await, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let hits = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        embed(&store, msg_id, cid, [1.0, 0.0, 0.0, 0.0], "test-model", 0).await;

        let hits = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].message_id, msg_id);
        assert_eq!(hits[0].conversation_id, cid);
        assert!((hits[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        embed(&store, msg_id, cid, [0.0, 1.0, 0.0, 0.0], "test-model", 0).await;

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        embed(&store, msg1, cid1, [1.0, 0.0, 0.0, 0.0], "m", 0).await;
        embed(&store, msg2, cid2, [1.0, 0.0, 0.0, 0.0], "m", 0).await;

        let only_first_conversation = SearchFilter {
            conversation_id: Some(cid1),
            role: None,
            category: None,
        };
        let hits = store
            .search(&[1.0, 0.0, 0.0, 0.0], 10, Some(only_first_conversation))
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        insert_metadata_row!(&pool, msg_id, 0_i64, &point_id1).unwrap();

        // Same (message, chunk, model) with a different point id must be rejected.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let duplicate = insert_metadata_row!(&pool, msg_id, 0_i64, &point_id2);
        assert!(duplicate.is_err());

        // A different chunk index for the same message and model is accepted.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        insert_metadata_row!(&pool, msg_id, 1_i64, &point_id3).unwrap();
    }

    #[tokio::test]
    async fn get_vectors_for_messages_returns_correct_vectors() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid, "user", "hello").await.unwrap();
        let msg2 = sqlite.save_message(cid, "user", "world").await.unwrap();

        embed(&store, msg1, cid, [1.0, 0.0, 0.0, 0.0], "m", 0).await;
        embed(&store, msg2, cid, [0.0, 1.0, 0.0, 0.0], "m", 0).await;

        let vectors = store.get_vectors_for_messages(&[msg1, msg2]).await.unwrap();
        assert_eq!(vectors.len(), 2);
        let v1 = vectors.get(&msg1).unwrap();
        let v2 = vectors.get(&msg2).unwrap();
        assert!((v1[0] - 1.0).abs() < f32::EPSILON);
        assert!((v2[1] - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn get_vectors_for_messages_missing_id_is_dropped() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid, "user", "present").await.unwrap();
        let msg_absent = MessageId(99_999);

        embed(&store, msg1, cid, [1.0, 0.0, 0.0, 0.0], "m", 0).await;

        let vectors = store
            .get_vectors_for_messages(&[msg1, msg_absent])
            .await
            .unwrap();
        assert_eq!(vectors.len(), 1);
        assert!(vectors.contains_key(&msg1));
        assert!(!vectors.contains_key(&msg_absent));
    }

    #[tokio::test]
    async fn get_vectors_for_messages_empty_input() {
        let (store, _sqlite) = setup_with_store().await;
        let vectors = store.get_vectors_for_messages(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn get_vectors_for_messages_chunk_index_0_only() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg = sqlite.save_message(cid, "user", "chunked").await.unwrap();

        // Two chunks for the same message; only chunk 0 should be returned.
        embed(&store, msg, cid, [1.0, 0.0, 0.0, 0.0], "m", 0).await;
        embed(&store, msg, cid, [0.0, 0.0, 1.0, 0.0], "m", 1).await;

        let vectors = store.get_vectors_for_messages(&[msg]).await.unwrap();
        assert_eq!(vectors.len(), 1);
        let v = vectors.get(&msg).unwrap();
        assert!(
            (v[0] - 1.0).abs() < f32::EPSILON,
            "expected chunk_index=0 vector"
        );
    }

    #[tokio::test]
    async fn embedding_store_delete_by_message_ids_resolves_via_metadata() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        embed(&store, msg_id, cid, [1.0, 0.0, 0.0, 0.0], "test-model", 0).await;

        assert!(store.has_embedding(msg_id).await.unwrap());

        let deleted = store.delete_by_message_ids(&[msg_id]).await.unwrap();
        assert_eq!(deleted, 1, "one point id should have been targeted");

        let pool = sqlite.pool().clone();
        let remaining = metadata_row_count(&pool, msg_id).await;
        assert_eq!(
            remaining, 1,
            "embeddings_metadata row must survive delete_by_message_ids"
        );
    }

    #[tokio::test]
    async fn embedding_store_delete_by_message_ids_empty_slice_is_noop() {
        let (store, _sqlite) = setup_with_store().await;
        let deleted = store.delete_by_message_ids(&[]).await.unwrap();
        assert_eq!(deleted, 0);
    }
}