1pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Classifies a stored message as either an ordinary message or a summary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
42
/// Name of the default vector collection that [`EmbeddingStore`] reads from and writes to.
const COLLECTION_NAME: &str = "zeph_conversations";
44
45pub async fn ensure_qdrant_collection(
53 ops: &QdrantOps,
54 collection: &str,
55 vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57 ops.ensure_collection(collection, vector_size).await
58}
59
60pub struct EmbeddingStore {
65 ops: Box<dyn VectorStore>,
66 collection: String,
67 pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("EmbeddingStore")
73 .field("collection", &self.collection)
74 .finish_non_exhaustive()
75 }
76}
77
78#[derive(Debug)]
80pub struct SearchFilter {
81 pub conversation_id: Option<ConversationId>,
83 pub role: Option<String>,
85 pub category: Option<String>,
88}
89
90#[derive(Debug)]
92pub struct SearchResult {
93 pub message_id: MessageId,
95 pub conversation_id: ConversationId,
97 pub score: f32,
99}
100
101impl EmbeddingStore {
102 pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
111 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
112
113 Ok(Self {
114 ops: Box::new(ops),
115 collection: COLLECTION_NAME.into(),
116 pool,
117 })
118 }
119
120 #[must_use]
124 pub fn new_sqlite(pool: DbPool) -> Self {
125 let ops = DbVectorStore::new(pool.clone());
126 Self {
127 ops: Box::new(ops),
128 collection: COLLECTION_NAME.into(),
129 pool,
130 }
131 }
132
133 #[must_use]
134 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
135 Self {
136 ops: store,
137 collection: COLLECTION_NAME.into(),
138 pool,
139 }
140 }
141
142 pub async fn health_check(&self) -> bool {
143 self.ops.health_check().await.unwrap_or(false)
144 }
145
146 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
154 self.ops
155 .ensure_collection(&self.collection, vector_size)
156 .await?;
157 self.ops
160 .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
161 .await?;
162 Ok(())
163 }
164
165 #[allow(clippy::too_many_arguments)]
175 pub async fn store_with_tool_context(
176 &self,
177 message_id: MessageId,
178 conversation_id: ConversationId,
179 role: &str,
180 vector: Vec<f32>,
181 kind: MessageKind,
182 model: &str,
183 chunk_index: u32,
184 tool_name: &str,
185 exit_code: Option<i32>,
186 timestamp: Option<&str>,
187 ) -> Result<String, MemoryError> {
188 let point_id = uuid::Uuid::new_v4().to_string();
189 let dimensions = i64::try_from(vector.len())?;
190
191 let mut payload = std::collections::HashMap::from([
192 ("message_id".to_owned(), serde_json::json!(message_id.0)),
193 (
194 "conversation_id".to_owned(),
195 serde_json::json!(conversation_id.0),
196 ),
197 ("role".to_owned(), serde_json::json!(role)),
198 (
199 "is_summary".to_owned(),
200 serde_json::json!(kind.is_summary()),
201 ),
202 ("tool_name".to_owned(), serde_json::json!(tool_name)),
203 ]);
204 if let Some(code) = exit_code {
205 payload.insert("exit_code".to_owned(), serde_json::json!(code));
206 }
207 if let Some(ts) = timestamp {
208 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
209 }
210
211 let point = VectorPoint {
212 id: point_id.clone(),
213 vector,
214 payload,
215 };
216
217 self.ops.upsert(&self.collection, vec![point]).await?;
218
219 let chunk_index_i64 = i64::from(chunk_index);
220 zeph_db::query(sql!(
221 "INSERT INTO embeddings_metadata \
222 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
223 VALUES (?, ?, ?, ?, ?) \
224 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
225 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
226 ))
227 .bind(message_id)
228 .bind(chunk_index_i64)
229 .bind(&point_id)
230 .bind(dimensions)
231 .bind(model)
232 .execute(&self.pool)
233 .await?;
234
235 Ok(point_id)
236 }
237
238 #[allow(clippy::too_many_arguments)]
249 pub async fn store(
250 &self,
251 message_id: MessageId,
252 conversation_id: ConversationId,
253 role: &str,
254 vector: Vec<f32>,
255 kind: MessageKind,
256 model: &str,
257 chunk_index: u32,
258 ) -> Result<String, MemoryError> {
259 let point_id = uuid::Uuid::new_v4().to_string();
260 let dimensions = i64::try_from(vector.len())?;
261
262 let payload = std::collections::HashMap::from([
263 ("message_id".to_owned(), serde_json::json!(message_id.0)),
264 (
265 "conversation_id".to_owned(),
266 serde_json::json!(conversation_id.0),
267 ),
268 ("role".to_owned(), serde_json::json!(role)),
269 (
270 "is_summary".to_owned(),
271 serde_json::json!(kind.is_summary()),
272 ),
273 ]);
274
275 let point = VectorPoint {
276 id: point_id.clone(),
277 vector,
278 payload,
279 };
280
281 self.ops.upsert(&self.collection, vec![point]).await?;
282
283 let chunk_index_i64 = i64::from(chunk_index);
284 zeph_db::query(sql!(
285 "INSERT INTO embeddings_metadata \
286 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
287 VALUES (?, ?, ?, ?, ?) \
288 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
289 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
290 ))
291 .bind(message_id)
292 .bind(chunk_index_i64)
293 .bind(&point_id)
294 .bind(dimensions)
295 .bind(model)
296 .execute(&self.pool)
297 .await?;
298
299 Ok(point_id)
300 }
301
302 #[allow(clippy::too_many_arguments)]
316 pub async fn store_with_category(
317 &self,
318 message_id: MessageId,
319 conversation_id: ConversationId,
320 role: &str,
321 vector: Vec<f32>,
322 kind: MessageKind,
323 model: &str,
324 chunk_index: u32,
325 category: Option<&str>,
326 ) -> Result<String, MemoryError> {
327 let point_id = uuid::Uuid::new_v4().to_string();
328 let dimensions = i64::try_from(vector.len())?;
329
330 let mut payload = std::collections::HashMap::from([
331 ("message_id".to_owned(), serde_json::json!(message_id.0)),
332 (
333 "conversation_id".to_owned(),
334 serde_json::json!(conversation_id.0),
335 ),
336 ("role".to_owned(), serde_json::json!(role)),
337 (
338 "is_summary".to_owned(),
339 serde_json::json!(kind.is_summary()),
340 ),
341 ]);
342 if let Some(cat) = category {
343 payload.insert("category".to_owned(), serde_json::json!(cat));
344 }
345
346 let point = VectorPoint {
347 id: point_id.clone(),
348 vector,
349 payload,
350 };
351
352 self.ops.upsert(&self.collection, vec![point]).await?;
353
354 let chunk_index_i64 = i64::from(chunk_index);
355 zeph_db::query(sql!(
356 "INSERT INTO embeddings_metadata \
357 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
358 VALUES (?, ?, ?, ?, ?) \
359 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
360 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
361 ))
362 .bind(message_id)
363 .bind(chunk_index_i64)
364 .bind(&point_id)
365 .bind(dimensions)
366 .bind(model)
367 .execute(&self.pool)
368 .await?;
369
370 Ok(point_id)
371 }
372
373 pub async fn search(
379 &self,
380 query_vector: &[f32],
381 limit: usize,
382 filter: Option<SearchFilter>,
383 ) -> Result<Vec<SearchResult>, MemoryError> {
384 let limit_u64 = u64::try_from(limit)?;
385
386 let vector_filter = filter.as_ref().and_then(|f| {
387 let mut must = Vec::new();
388 if let Some(cid) = f.conversation_id {
389 must.push(FieldCondition {
390 field: "conversation_id".into(),
391 value: FieldValue::Integer(cid.0),
392 });
393 }
394 if let Some(ref role) = f.role {
395 must.push(FieldCondition {
396 field: "role".into(),
397 value: FieldValue::Text(role.clone()),
398 });
399 }
400 if let Some(ref category) = f.category {
401 must.push(FieldCondition {
402 field: "category".into(),
403 value: FieldValue::Text(category.clone()),
404 });
405 }
406 if must.is_empty() {
407 None
408 } else {
409 Some(VectorFilter {
410 must,
411 must_not: vec![],
412 })
413 }
414 });
415
416 let results = self
417 .ops
418 .search(
419 &self.collection,
420 query_vector.to_vec(),
421 limit_u64,
422 vector_filter,
423 )
424 .await?;
425
426 let mut best: std::collections::HashMap<MessageId, SearchResult> =
429 std::collections::HashMap::new();
430 for point in results {
431 let Some(message_id) = point
432 .payload
433 .get("message_id")
434 .and_then(serde_json::Value::as_i64)
435 else {
436 continue;
437 };
438 let Some(conversation_id) = point
439 .payload
440 .get("conversation_id")
441 .and_then(serde_json::Value::as_i64)
442 else {
443 continue;
444 };
445 let message_id = MessageId(message_id);
446 let entry = best.entry(message_id).or_insert(SearchResult {
447 message_id,
448 conversation_id: ConversationId(conversation_id),
449 score: f32::NEG_INFINITY,
450 });
451 if point.score > entry.score {
452 entry.score = point.score;
453 }
454 }
455
456 let mut search_results: Vec<SearchResult> = best.into_values().collect();
457 search_results.sort_by(|a, b| {
458 b.score
459 .partial_cmp(&a.score)
460 .unwrap_or(std::cmp::Ordering::Equal)
461 });
462 search_results.truncate(limit);
463
464 Ok(search_results)
465 }
466
467 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
473 self.ops.collection_exists(name).await.map_err(Into::into)
474 }
475
476 pub async fn ensure_named_collection(
482 &self,
483 name: &str,
484 vector_size: u64,
485 ) -> Result<(), MemoryError> {
486 self.ops.ensure_collection(name, vector_size).await?;
487 Ok(())
488 }
489
490 pub async fn store_to_collection(
498 &self,
499 collection: &str,
500 payload: serde_json::Value,
501 vector: Vec<f32>,
502 ) -> Result<String, MemoryError> {
503 let point_id = uuid::Uuid::new_v4().to_string();
504 let payload_map: std::collections::HashMap<String, serde_json::Value> =
505 serde_json::from_value(payload)?;
506 let point = VectorPoint {
507 id: point_id.clone(),
508 vector,
509 payload: payload_map,
510 };
511 self.ops.upsert(collection, vec![point]).await?;
512 Ok(point_id)
513 }
514
515 pub async fn upsert_to_collection(
523 &self,
524 collection: &str,
525 point_id: &str,
526 payload: serde_json::Value,
527 vector: Vec<f32>,
528 ) -> Result<(), MemoryError> {
529 let payload_map: std::collections::HashMap<String, serde_json::Value> =
530 serde_json::from_value(payload)?;
531 let point = VectorPoint {
532 id: point_id.to_owned(),
533 vector,
534 payload: payload_map,
535 };
536 self.ops.upsert(collection, vec![point]).await?;
537 Ok(())
538 }
539
540 pub async fn search_collection(
546 &self,
547 collection: &str,
548 query_vector: &[f32],
549 limit: usize,
550 filter: Option<VectorFilter>,
551 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
552 let limit_u64 = u64::try_from(limit)?;
553 let results = self
554 .ops
555 .search(collection, query_vector.to_vec(), limit_u64, filter)
556 .await?;
557 Ok(results)
558 }
559
560 pub async fn get_vectors(
568 &self,
569 ids: &[MessageId],
570 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
571 if ids.is_empty() {
572 return Ok(std::collections::HashMap::new());
573 }
574
575 let placeholders = zeph_db::placeholder_list(1, ids.len());
576 let query = format!(
577 "SELECT em.message_id, vp.vector \
578 FROM embeddings_metadata em \
579 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
580 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
581 );
582 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
583 for &id in ids {
584 q = q.bind(id);
585 }
586
587 let rows = q.fetch_all(&self.pool).await?;
588
589 let map = rows
590 .into_iter()
591 .filter_map(|(msg_id, blob)| {
592 if blob.len() % 4 != 0 {
593 return None;
594 }
595 let vec: Vec<f32> = blob
596 .chunks_exact(4)
597 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
598 .collect();
599 Some((msg_id, vec))
600 })
601 .collect();
602
603 Ok(map)
604 }
605
606 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
612 let row: (i64,) = zeph_db::query_as(sql!(
613 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
614 ))
615 .bind(message_id)
616 .fetch_one(&self.pool)
617 .await?;
618
619 Ok(row.0 > 0)
620 }
621
622 pub async fn is_epoch_current(
632 &self,
633 entity_name: &str,
634 qdrant_epoch: u64,
635 ) -> Result<bool, MemoryError> {
636 let row: Option<(i64,)> = zeph_db::query_as(sql!(
637 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
638 ))
639 .bind(entity_name)
640 .fetch_optional(&self.pool)
641 .await?;
642
643 match row {
644 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
646 }
647 }
648}
649
650#[cfg(test)]
651mod tests {
652 use super::*;
653 use crate::in_memory_store::InMemoryVectorStore;
654 use crate::store::SqliteStore;
655
656 async fn setup() -> (SqliteStore, DbPool) {
657 let store = SqliteStore::new(":memory:").await.unwrap();
658 let pool = store.pool().clone();
659 (store, pool)
660 }
661
662 async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
663 let sqlite = SqliteStore::new(":memory:").await.unwrap();
664 let pool = sqlite.pool().clone();
665 let mem_store = Box::new(InMemoryVectorStore::new());
666 let embedding_store = EmbeddingStore::with_store(mem_store, pool);
667 embedding_store.ensure_collection(4).await.unwrap();
669 (embedding_store, sqlite)
670 }
671
672 #[tokio::test]
673 async fn has_embedding_returns_false_when_none() {
674 let (_store, pool) = setup().await;
675
676 let row: (i64,) = zeph_db::query_as(sql!(
677 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
678 ))
679 .bind(999_i64)
680 .fetch_one(&pool)
681 .await
682 .unwrap();
683
684 assert_eq!(row.0, 0);
685 }
686
687 #[tokio::test]
688 async fn insert_and_query_embeddings_metadata() {
689 let (sqlite, pool) = setup().await;
690 let cid = sqlite.create_conversation().await.unwrap();
691 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
692
693 let point_id = uuid::Uuid::new_v4().to_string();
694 zeph_db::query(sql!(
695 "INSERT INTO embeddings_metadata \
696 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
697 VALUES (?, ?, ?, ?, ?)"
698 ))
699 .bind(msg_id)
700 .bind(0_i64)
701 .bind(&point_id)
702 .bind(768_i64)
703 .bind("qwen3-embedding")
704 .execute(&pool)
705 .await
706 .unwrap();
707
708 let row: (i64,) = zeph_db::query_as(sql!(
709 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
710 ))
711 .bind(msg_id)
712 .fetch_one(&pool)
713 .await
714 .unwrap();
715 assert_eq!(row.0, 1);
716 }
717
718 #[tokio::test]
719 async fn embedding_store_search_empty_returns_empty() {
720 let (store, _sqlite) = setup_with_store().await;
721 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
722 assert!(results.is_empty());
723 }
724
725 #[tokio::test]
726 async fn embedding_store_store_and_search() {
727 let (store, sqlite) = setup_with_store().await;
728 let cid = sqlite.create_conversation().await.unwrap();
729 let msg_id = sqlite
730 .save_message(cid, "user", "test message")
731 .await
732 .unwrap();
733
734 store
735 .store(
736 msg_id,
737 cid,
738 "user",
739 vec![1.0, 0.0, 0.0, 0.0],
740 MessageKind::Regular,
741 "test-model",
742 0,
743 )
744 .await
745 .unwrap();
746
747 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
748 assert_eq!(results.len(), 1);
749 assert_eq!(results[0].message_id, msg_id);
750 assert_eq!(results[0].conversation_id, cid);
751 assert!((results[0].score - 1.0).abs() < 0.001);
752 }
753
754 #[tokio::test]
755 async fn embedding_store_has_embedding_false_for_unknown() {
756 let (store, sqlite) = setup_with_store().await;
757 let cid = sqlite.create_conversation().await.unwrap();
758 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
759 assert!(!store.has_embedding(msg_id).await.unwrap());
760 }
761
762 #[tokio::test]
763 async fn embedding_store_has_embedding_true_after_store() {
764 let (store, sqlite) = setup_with_store().await;
765 let cid = sqlite.create_conversation().await.unwrap();
766 let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
767
768 store
769 .store(
770 msg_id,
771 cid,
772 "user",
773 vec![0.0, 1.0, 0.0, 0.0],
774 MessageKind::Regular,
775 "test-model",
776 0,
777 )
778 .await
779 .unwrap();
780
781 assert!(store.has_embedding(msg_id).await.unwrap());
782 }
783
784 #[tokio::test]
785 async fn embedding_store_search_with_conversation_filter() {
786 let (store, sqlite) = setup_with_store().await;
787 let cid1 = sqlite.create_conversation().await.unwrap();
788 let cid2 = sqlite.create_conversation().await.unwrap();
789 let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
790 let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
791
792 store
793 .store(
794 msg1,
795 cid1,
796 "user",
797 vec![1.0, 0.0, 0.0, 0.0],
798 MessageKind::Regular,
799 "m",
800 0,
801 )
802 .await
803 .unwrap();
804 store
805 .store(
806 msg2,
807 cid2,
808 "user",
809 vec![1.0, 0.0, 0.0, 0.0],
810 MessageKind::Regular,
811 "m",
812 0,
813 )
814 .await
815 .unwrap();
816
817 let results = store
818 .search(
819 &[1.0, 0.0, 0.0, 0.0],
820 10,
821 Some(SearchFilter {
822 conversation_id: Some(cid1),
823 role: None,
824 category: None,
825 }),
826 )
827 .await
828 .unwrap();
829 assert_eq!(results.len(), 1);
830 assert_eq!(results[0].conversation_id, cid1);
831 }
832
833 #[tokio::test]
834 async fn unique_constraint_on_message_chunk_and_model() {
835 let (sqlite, pool) = setup().await;
836 let cid = sqlite.create_conversation().await.unwrap();
837 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
838
839 let point_id1 = uuid::Uuid::new_v4().to_string();
840 zeph_db::query(sql!(
841 "INSERT INTO embeddings_metadata \
842 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
843 VALUES (?, ?, ?, ?, ?)"
844 ))
845 .bind(msg_id)
846 .bind(0_i64)
847 .bind(&point_id1)
848 .bind(768_i64)
849 .bind("qwen3-embedding")
850 .execute(&pool)
851 .await
852 .unwrap();
853
854 let point_id2 = uuid::Uuid::new_v4().to_string();
856 let result = zeph_db::query(sql!(
857 "INSERT INTO embeddings_metadata \
858 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
859 VALUES (?, ?, ?, ?, ?)"
860 ))
861 .bind(msg_id)
862 .bind(0_i64)
863 .bind(&point_id2)
864 .bind(768_i64)
865 .bind("qwen3-embedding")
866 .execute(&pool)
867 .await;
868 assert!(result.is_err());
869
870 let point_id3 = uuid::Uuid::new_v4().to_string();
872 zeph_db::query(sql!(
873 "INSERT INTO embeddings_metadata \
874 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
875 VALUES (?, ?, ?, ?, ?)"
876 ))
877 .bind(msg_id)
878 .bind(1_i64)
879 .bind(&point_id3)
880 .bind(768_i64)
881 .bind("qwen3-embedding")
882 .execute(&pool)
883 .await
884 .unwrap();
885 }
886}