1pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Kind of message an embedding belongs to.
///
/// The store methods in this file record it in the point payload as the boolean
/// `is_summary` flag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageKind {
    /// Any message that is not a summary.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
42
/// Collection name that every `EmbeddingStore` constructor in this file assigns to the store.
const COLLECTION_NAME: &str = "zeph_conversations";
44
45pub async fn ensure_qdrant_collection(
53 ops: &QdrantOps,
54 collection: &str,
55 vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57 ops.ensure_collection(collection, vector_size).await
58}
59
60pub struct EmbeddingStore {
65 ops: Box<dyn VectorStore>,
66 collection: String,
67 pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("EmbeddingStore")
73 .field("collection", &self.collection)
74 .finish_non_exhaustive()
75 }
76}
77
78#[derive(Debug)]
80pub struct SearchFilter {
81 pub conversation_id: Option<ConversationId>,
83 pub role: Option<String>,
85 pub category: Option<String>,
88}
89
90#[derive(Debug)]
92pub struct SearchResult {
93 pub message_id: MessageId,
95 pub conversation_id: ConversationId,
97 pub score: f32,
99}
100
101impl EmbeddingStore {
102 pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
111 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
112
113 Ok(Self {
114 ops: Box::new(ops),
115 collection: COLLECTION_NAME.into(),
116 pool,
117 })
118 }
119
120 #[must_use]
124 pub fn new_sqlite(pool: DbPool) -> Self {
125 let ops = DbVectorStore::new(pool.clone());
126 Self {
127 ops: Box::new(ops),
128 collection: COLLECTION_NAME.into(),
129 pool,
130 }
131 }
132
133 #[must_use]
134 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
135 Self {
136 ops: store,
137 collection: COLLECTION_NAME.into(),
138 pool,
139 }
140 }
141
142 pub async fn health_check(&self) -> bool {
143 self.ops.health_check().await.unwrap_or(false)
144 }
145
146 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
154 self.ops
155 .ensure_collection(&self.collection, vector_size)
156 .await?;
157 self.ops
160 .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
161 .await?;
162 Ok(())
163 }
164
165 #[allow(clippy::too_many_arguments)]
175 pub async fn store_with_tool_context(
176 &self,
177 message_id: MessageId,
178 conversation_id: ConversationId,
179 role: &str,
180 vector: Vec<f32>,
181 kind: MessageKind,
182 model: &str,
183 chunk_index: u32,
184 tool_name: &str,
185 exit_code: Option<i32>,
186 timestamp: Option<&str>,
187 ) -> Result<String, MemoryError> {
188 let point_id = uuid::Uuid::new_v4().to_string();
189 let dimensions = i64::try_from(vector.len())?;
190
191 let mut payload = std::collections::HashMap::from([
192 ("message_id".to_owned(), serde_json::json!(message_id.0)),
193 (
194 "conversation_id".to_owned(),
195 serde_json::json!(conversation_id.0),
196 ),
197 ("role".to_owned(), serde_json::json!(role)),
198 (
199 "is_summary".to_owned(),
200 serde_json::json!(kind.is_summary()),
201 ),
202 ("tool_name".to_owned(), serde_json::json!(tool_name)),
203 ]);
204 if let Some(code) = exit_code {
205 payload.insert("exit_code".to_owned(), serde_json::json!(code));
206 }
207 if let Some(ts) = timestamp {
208 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
209 }
210
211 let point = VectorPoint {
212 id: point_id.clone(),
213 vector,
214 payload,
215 };
216
217 self.ops.upsert(&self.collection, vec![point]).await?;
218
219 let chunk_index_i64 = i64::from(chunk_index);
220 zeph_db::query(sql!(
221 "INSERT INTO embeddings_metadata \
222 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
223 VALUES (?, ?, ?, ?, ?) \
224 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
225 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
226 ))
227 .bind(message_id)
228 .bind(chunk_index_i64)
229 .bind(&point_id)
230 .bind(dimensions)
231 .bind(model)
232 .execute(&self.pool)
233 .await?;
234
235 Ok(point_id)
236 }
237
238 #[allow(clippy::too_many_arguments)]
249 pub async fn store(
250 &self,
251 message_id: MessageId,
252 conversation_id: ConversationId,
253 role: &str,
254 vector: Vec<f32>,
255 kind: MessageKind,
256 model: &str,
257 chunk_index: u32,
258 ) -> Result<String, MemoryError> {
259 let point_id = uuid::Uuid::new_v4().to_string();
260 let dimensions = i64::try_from(vector.len())?;
261
262 let payload = std::collections::HashMap::from([
263 ("message_id".to_owned(), serde_json::json!(message_id.0)),
264 (
265 "conversation_id".to_owned(),
266 serde_json::json!(conversation_id.0),
267 ),
268 ("role".to_owned(), serde_json::json!(role)),
269 (
270 "is_summary".to_owned(),
271 serde_json::json!(kind.is_summary()),
272 ),
273 ]);
274
275 let point = VectorPoint {
276 id: point_id.clone(),
277 vector,
278 payload,
279 };
280
281 self.ops.upsert(&self.collection, vec![point]).await?;
282
283 let chunk_index_i64 = i64::from(chunk_index);
284 zeph_db::query(sql!(
285 "INSERT INTO embeddings_metadata \
286 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
287 VALUES (?, ?, ?, ?, ?) \
288 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
289 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
290 ))
291 .bind(message_id)
292 .bind(chunk_index_i64)
293 .bind(&point_id)
294 .bind(dimensions)
295 .bind(model)
296 .execute(&self.pool)
297 .await?;
298
299 Ok(point_id)
300 }
301
302 #[allow(clippy::too_many_arguments)]
316 pub async fn store_with_category(
317 &self,
318 message_id: MessageId,
319 conversation_id: ConversationId,
320 role: &str,
321 vector: Vec<f32>,
322 kind: MessageKind,
323 model: &str,
324 chunk_index: u32,
325 category: Option<&str>,
326 ) -> Result<String, MemoryError> {
327 let point_id = uuid::Uuid::new_v4().to_string();
328 let dimensions = i64::try_from(vector.len())?;
329
330 let mut payload = std::collections::HashMap::from([
331 ("message_id".to_owned(), serde_json::json!(message_id.0)),
332 (
333 "conversation_id".to_owned(),
334 serde_json::json!(conversation_id.0),
335 ),
336 ("role".to_owned(), serde_json::json!(role)),
337 (
338 "is_summary".to_owned(),
339 serde_json::json!(kind.is_summary()),
340 ),
341 ]);
342 if let Some(cat) = category {
343 payload.insert("category".to_owned(), serde_json::json!(cat));
344 }
345
346 let point = VectorPoint {
347 id: point_id.clone(),
348 vector,
349 payload,
350 };
351
352 self.ops.upsert(&self.collection, vec![point]).await?;
353
354 let chunk_index_i64 = i64::from(chunk_index);
355 zeph_db::query(sql!(
356 "INSERT INTO embeddings_metadata \
357 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
358 VALUES (?, ?, ?, ?, ?) \
359 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
360 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
361 ))
362 .bind(message_id)
363 .bind(chunk_index_i64)
364 .bind(&point_id)
365 .bind(dimensions)
366 .bind(model)
367 .execute(&self.pool)
368 .await?;
369
370 Ok(point_id)
371 }
372
373 pub async fn search(
379 &self,
380 query_vector: &[f32],
381 limit: usize,
382 filter: Option<SearchFilter>,
383 ) -> Result<Vec<SearchResult>, MemoryError> {
384 let limit_u64 = u64::try_from(limit)?;
385
386 let vector_filter = filter.as_ref().and_then(|f| {
387 let mut must = Vec::new();
388 if let Some(cid) = f.conversation_id {
389 must.push(FieldCondition {
390 field: "conversation_id".into(),
391 value: FieldValue::Integer(cid.0),
392 });
393 }
394 if let Some(ref role) = f.role {
395 must.push(FieldCondition {
396 field: "role".into(),
397 value: FieldValue::Text(role.clone()),
398 });
399 }
400 if let Some(ref category) = f.category {
401 must.push(FieldCondition {
402 field: "category".into(),
403 value: FieldValue::Text(category.clone()),
404 });
405 }
406 if must.is_empty() {
407 None
408 } else {
409 Some(VectorFilter {
410 must,
411 must_not: vec![],
412 })
413 }
414 });
415
416 let results = self
417 .ops
418 .search(
419 &self.collection,
420 query_vector.to_vec(),
421 limit_u64,
422 vector_filter,
423 )
424 .await?;
425
426 let mut best: std::collections::HashMap<MessageId, SearchResult> =
429 std::collections::HashMap::new();
430 for point in results {
431 let Some(message_id) = point
432 .payload
433 .get("message_id")
434 .and_then(serde_json::Value::as_i64)
435 else {
436 continue;
437 };
438 let Some(conversation_id) = point
439 .payload
440 .get("conversation_id")
441 .and_then(serde_json::Value::as_i64)
442 else {
443 continue;
444 };
445 let message_id = MessageId(message_id);
446 let entry = best.entry(message_id).or_insert(SearchResult {
447 message_id,
448 conversation_id: ConversationId(conversation_id),
449 score: f32::NEG_INFINITY,
450 });
451 if point.score > entry.score {
452 entry.score = point.score;
453 }
454 }
455
456 let mut search_results: Vec<SearchResult> = best.into_values().collect();
457 search_results.sort_by(|a, b| {
458 b.score
459 .partial_cmp(&a.score)
460 .unwrap_or(std::cmp::Ordering::Equal)
461 });
462 search_results.truncate(limit);
463
464 Ok(search_results)
465 }
466
467 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
473 self.ops.collection_exists(name).await.map_err(Into::into)
474 }
475
476 pub async fn ensure_named_collection(
482 &self,
483 name: &str,
484 vector_size: u64,
485 ) -> Result<(), MemoryError> {
486 self.ops.ensure_collection(name, vector_size).await?;
487 Ok(())
488 }
489
490 pub async fn store_to_collection(
498 &self,
499 collection: &str,
500 payload: serde_json::Value,
501 vector: Vec<f32>,
502 ) -> Result<String, MemoryError> {
503 let point_id = uuid::Uuid::new_v4().to_string();
504 let payload_map: std::collections::HashMap<String, serde_json::Value> =
505 serde_json::from_value(payload)?;
506 let point = VectorPoint {
507 id: point_id.clone(),
508 vector,
509 payload: payload_map,
510 };
511 self.ops.upsert(collection, vec![point]).await?;
512 Ok(point_id)
513 }
514
515 pub async fn upsert_to_collection(
523 &self,
524 collection: &str,
525 point_id: &str,
526 payload: serde_json::Value,
527 vector: Vec<f32>,
528 ) -> Result<(), MemoryError> {
529 let payload_map: std::collections::HashMap<String, serde_json::Value> =
530 serde_json::from_value(payload)?;
531 let point = VectorPoint {
532 id: point_id.to_owned(),
533 vector,
534 payload: payload_map,
535 };
536 self.ops.upsert(collection, vec![point]).await?;
537 Ok(())
538 }
539
540 pub async fn search_collection(
546 &self,
547 collection: &str,
548 query_vector: &[f32],
549 limit: usize,
550 filter: Option<VectorFilter>,
551 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
552 let limit_u64 = u64::try_from(limit)?;
553 let results = self
554 .ops
555 .search(collection, query_vector.to_vec(), limit_u64, filter)
556 .await?;
557 Ok(results)
558 }
559
560 pub async fn get_vectors_from_collection(
570 &self,
571 collection: &str,
572 point_ids: &[String],
573 ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
574 if point_ids.is_empty() {
575 return Ok(std::collections::HashMap::new());
576 }
577 match self.ops.get_points(collection, point_ids.to_vec()).await {
578 Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
579 Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
580 Err(e) => Err(MemoryError::VectorStore(e)),
581 }
582 }
583
584 pub async fn get_vectors(
592 &self,
593 ids: &[MessageId],
594 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
595 if ids.is_empty() {
596 return Ok(std::collections::HashMap::new());
597 }
598
599 let placeholders = zeph_db::placeholder_list(1, ids.len());
600 let query = format!(
601 "SELECT em.message_id, vp.vector \
602 FROM embeddings_metadata em \
603 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
604 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
605 );
606 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
607 for &id in ids {
608 q = q.bind(id);
609 }
610
611 let rows = q.fetch_all(&self.pool).await?;
612
613 let map = rows
614 .into_iter()
615 .filter_map(|(msg_id, blob)| {
616 if blob.len() % 4 != 0 {
617 return None;
618 }
619 let vec: Vec<f32> = blob
620 .chunks_exact(4)
621 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
622 .collect();
623 Some((msg_id, vec))
624 })
625 .collect();
626
627 Ok(map)
628 }
629
630 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
636 let row: (i64,) = zeph_db::query_as(sql!(
637 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
638 ))
639 .bind(message_id)
640 .fetch_one(&self.pool)
641 .await?;
642
643 Ok(row.0 > 0)
644 }
645
646 pub async fn is_epoch_current(
656 &self,
657 entity_name: &str,
658 qdrant_epoch: u64,
659 ) -> Result<bool, MemoryError> {
660 let row: Option<(i64,)> = zeph_db::query_as(sql!(
661 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
662 ))
663 .bind(entity_name)
664 .fetch_optional(&self.pool)
665 .await?;
666
667 match row {
668 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
670 }
671 }
672}
673
674#[cfg(test)]
675mod tests {
676 use super::*;
677 use crate::in_memory_store::InMemoryVectorStore;
678 use crate::store::SqliteStore;
679
680 async fn setup() -> (SqliteStore, DbPool) {
681 let store = SqliteStore::new(":memory:").await.unwrap();
682 let pool = store.pool().clone();
683 (store, pool)
684 }
685
686 async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
687 let sqlite = SqliteStore::new(":memory:").await.unwrap();
688 let pool = sqlite.pool().clone();
689 let mem_store = Box::new(InMemoryVectorStore::new());
690 let embedding_store = EmbeddingStore::with_store(mem_store, pool);
691 embedding_store.ensure_collection(4).await.unwrap();
693 (embedding_store, sqlite)
694 }
695
696 #[tokio::test]
697 async fn has_embedding_returns_false_when_none() {
698 let (_store, pool) = setup().await;
699
700 let row: (i64,) = zeph_db::query_as(sql!(
701 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
702 ))
703 .bind(999_i64)
704 .fetch_one(&pool)
705 .await
706 .unwrap();
707
708 assert_eq!(row.0, 0);
709 }
710
711 #[tokio::test]
712 async fn insert_and_query_embeddings_metadata() {
713 let (sqlite, pool) = setup().await;
714 let cid = sqlite.create_conversation().await.unwrap();
715 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
716
717 let point_id = uuid::Uuid::new_v4().to_string();
718 zeph_db::query(sql!(
719 "INSERT INTO embeddings_metadata \
720 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
721 VALUES (?, ?, ?, ?, ?)"
722 ))
723 .bind(msg_id)
724 .bind(0_i64)
725 .bind(&point_id)
726 .bind(768_i64)
727 .bind("qwen3-embedding")
728 .execute(&pool)
729 .await
730 .unwrap();
731
732 let row: (i64,) = zeph_db::query_as(sql!(
733 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
734 ))
735 .bind(msg_id)
736 .fetch_one(&pool)
737 .await
738 .unwrap();
739 assert_eq!(row.0, 1);
740 }
741
742 #[tokio::test]
743 async fn embedding_store_search_empty_returns_empty() {
744 let (store, _sqlite) = setup_with_store().await;
745 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
746 assert!(results.is_empty());
747 }
748
749 #[tokio::test]
750 async fn embedding_store_store_and_search() {
751 let (store, sqlite) = setup_with_store().await;
752 let cid = sqlite.create_conversation().await.unwrap();
753 let msg_id = sqlite
754 .save_message(cid, "user", "test message")
755 .await
756 .unwrap();
757
758 store
759 .store(
760 msg_id,
761 cid,
762 "user",
763 vec![1.0, 0.0, 0.0, 0.0],
764 MessageKind::Regular,
765 "test-model",
766 0,
767 )
768 .await
769 .unwrap();
770
771 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
772 assert_eq!(results.len(), 1);
773 assert_eq!(results[0].message_id, msg_id);
774 assert_eq!(results[0].conversation_id, cid);
775 assert!((results[0].score - 1.0).abs() < 0.001);
776 }
777
778 #[tokio::test]
779 async fn embedding_store_has_embedding_false_for_unknown() {
780 let (store, sqlite) = setup_with_store().await;
781 let cid = sqlite.create_conversation().await.unwrap();
782 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
783 assert!(!store.has_embedding(msg_id).await.unwrap());
784 }
785
786 #[tokio::test]
787 async fn embedding_store_has_embedding_true_after_store() {
788 let (store, sqlite) = setup_with_store().await;
789 let cid = sqlite.create_conversation().await.unwrap();
790 let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
791
792 store
793 .store(
794 msg_id,
795 cid,
796 "user",
797 vec![0.0, 1.0, 0.0, 0.0],
798 MessageKind::Regular,
799 "test-model",
800 0,
801 )
802 .await
803 .unwrap();
804
805 assert!(store.has_embedding(msg_id).await.unwrap());
806 }
807
808 #[tokio::test]
809 async fn embedding_store_search_with_conversation_filter() {
810 let (store, sqlite) = setup_with_store().await;
811 let cid1 = sqlite.create_conversation().await.unwrap();
812 let cid2 = sqlite.create_conversation().await.unwrap();
813 let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
814 let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
815
816 store
817 .store(
818 msg1,
819 cid1,
820 "user",
821 vec![1.0, 0.0, 0.0, 0.0],
822 MessageKind::Regular,
823 "m",
824 0,
825 )
826 .await
827 .unwrap();
828 store
829 .store(
830 msg2,
831 cid2,
832 "user",
833 vec![1.0, 0.0, 0.0, 0.0],
834 MessageKind::Regular,
835 "m",
836 0,
837 )
838 .await
839 .unwrap();
840
841 let results = store
842 .search(
843 &[1.0, 0.0, 0.0, 0.0],
844 10,
845 Some(SearchFilter {
846 conversation_id: Some(cid1),
847 role: None,
848 category: None,
849 }),
850 )
851 .await
852 .unwrap();
853 assert_eq!(results.len(), 1);
854 assert_eq!(results[0].conversation_id, cid1);
855 }
856
857 #[tokio::test]
858 async fn unique_constraint_on_message_chunk_and_model() {
859 let (sqlite, pool) = setup().await;
860 let cid = sqlite.create_conversation().await.unwrap();
861 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
862
863 let point_id1 = uuid::Uuid::new_v4().to_string();
864 zeph_db::query(sql!(
865 "INSERT INTO embeddings_metadata \
866 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
867 VALUES (?, ?, ?, ?, ?)"
868 ))
869 .bind(msg_id)
870 .bind(0_i64)
871 .bind(&point_id1)
872 .bind(768_i64)
873 .bind("qwen3-embedding")
874 .execute(&pool)
875 .await
876 .unwrap();
877
878 let point_id2 = uuid::Uuid::new_v4().to_string();
880 let result = zeph_db::query(sql!(
881 "INSERT INTO embeddings_metadata \
882 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
883 VALUES (?, ?, ?, ?, ?)"
884 ))
885 .bind(msg_id)
886 .bind(0_i64)
887 .bind(&point_id2)
888 .bind(768_i64)
889 .bind("qwen3-embedding")
890 .execute(&pool)
891 .await;
892 assert!(result.is_err());
893
894 let point_id3 = uuid::Uuid::new_v4().to_string();
896 zeph_db::query(sql!(
897 "INSERT INTO embeddings_metadata \
898 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
899 VALUES (?, ?, ?, ?, ?)"
900 ))
901 .bind(msg_id)
902 .bind(1_i64)
903 .bind(&point_id3)
904 .bind(768_i64)
905 .bind("qwen3-embedding")
906 .execute(&pool)
907 .await
908 .unwrap();
909 }
910}