1pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Classifies a stored message as either an ordinary message or a summary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary (non-summary) message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only when this value is [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
28
/// Collection name assigned to every `EmbeddingStore` built by the constructors in this file.
const COLLECTION_NAME: &str = "zeph_conversations";
30
31pub async fn ensure_qdrant_collection(
39 ops: &QdrantOps,
40 collection: &str,
41 vector_size: u64,
42) -> Result<(), Box<qdrant_client::QdrantError>> {
43 ops.ensure_collection(collection, vector_size).await
44}
45
46pub struct EmbeddingStore {
47 ops: Box<dyn VectorStore>,
48 collection: String,
49 pool: DbPool,
50}
51
52impl std::fmt::Debug for EmbeddingStore {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("EmbeddingStore")
55 .field("collection", &self.collection)
56 .finish_non_exhaustive()
57 }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62 pub conversation_id: Option<ConversationId>,
63 pub role: Option<String>,
64 pub category: Option<String>,
67}
68
69#[derive(Debug)]
70pub struct SearchResult {
71 pub message_id: MessageId,
72 pub conversation_id: ConversationId,
73 pub score: f32,
74}
75
76impl EmbeddingStore {
77 pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
86 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
87
88 Ok(Self {
89 ops: Box::new(ops),
90 collection: COLLECTION_NAME.into(),
91 pool,
92 })
93 }
94
95 #[must_use]
99 pub fn new_sqlite(pool: DbPool) -> Self {
100 let ops = DbVectorStore::new(pool.clone());
101 Self {
102 ops: Box::new(ops),
103 collection: COLLECTION_NAME.into(),
104 pool,
105 }
106 }
107
108 #[must_use]
109 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
110 Self {
111 ops: store,
112 collection: COLLECTION_NAME.into(),
113 pool,
114 }
115 }
116
117 pub async fn health_check(&self) -> bool {
118 self.ops.health_check().await.unwrap_or(false)
119 }
120
121 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
129 self.ops
130 .ensure_collection(&self.collection, vector_size)
131 .await?;
132 self.ops
135 .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
136 .await?;
137 Ok(())
138 }
139
140 #[allow(clippy::too_many_arguments)]
150 pub async fn store_with_tool_context(
151 &self,
152 message_id: MessageId,
153 conversation_id: ConversationId,
154 role: &str,
155 vector: Vec<f32>,
156 kind: MessageKind,
157 model: &str,
158 chunk_index: u32,
159 tool_name: &str,
160 exit_code: Option<i32>,
161 timestamp: Option<&str>,
162 ) -> Result<String, MemoryError> {
163 let point_id = uuid::Uuid::new_v4().to_string();
164 let dimensions = i64::try_from(vector.len())?;
165
166 let mut payload = std::collections::HashMap::from([
167 ("message_id".to_owned(), serde_json::json!(message_id.0)),
168 (
169 "conversation_id".to_owned(),
170 serde_json::json!(conversation_id.0),
171 ),
172 ("role".to_owned(), serde_json::json!(role)),
173 (
174 "is_summary".to_owned(),
175 serde_json::json!(kind.is_summary()),
176 ),
177 ("tool_name".to_owned(), serde_json::json!(tool_name)),
178 ]);
179 if let Some(code) = exit_code {
180 payload.insert("exit_code".to_owned(), serde_json::json!(code));
181 }
182 if let Some(ts) = timestamp {
183 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
184 }
185
186 let point = VectorPoint {
187 id: point_id.clone(),
188 vector,
189 payload,
190 };
191
192 self.ops.upsert(&self.collection, vec![point]).await?;
193
194 let chunk_index_i64 = i64::from(chunk_index);
195 zeph_db::query(sql!(
196 "INSERT INTO embeddings_metadata \
197 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
198 VALUES (?, ?, ?, ?, ?) \
199 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
200 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
201 ))
202 .bind(message_id)
203 .bind(chunk_index_i64)
204 .bind(&point_id)
205 .bind(dimensions)
206 .bind(model)
207 .execute(&self.pool)
208 .await?;
209
210 Ok(point_id)
211 }
212
213 #[allow(clippy::too_many_arguments)]
224 pub async fn store(
225 &self,
226 message_id: MessageId,
227 conversation_id: ConversationId,
228 role: &str,
229 vector: Vec<f32>,
230 kind: MessageKind,
231 model: &str,
232 chunk_index: u32,
233 ) -> Result<String, MemoryError> {
234 let point_id = uuid::Uuid::new_v4().to_string();
235 let dimensions = i64::try_from(vector.len())?;
236
237 let payload = std::collections::HashMap::from([
238 ("message_id".to_owned(), serde_json::json!(message_id.0)),
239 (
240 "conversation_id".to_owned(),
241 serde_json::json!(conversation_id.0),
242 ),
243 ("role".to_owned(), serde_json::json!(role)),
244 (
245 "is_summary".to_owned(),
246 serde_json::json!(kind.is_summary()),
247 ),
248 ]);
249
250 let point = VectorPoint {
251 id: point_id.clone(),
252 vector,
253 payload,
254 };
255
256 self.ops.upsert(&self.collection, vec![point]).await?;
257
258 let chunk_index_i64 = i64::from(chunk_index);
259 zeph_db::query(sql!(
260 "INSERT INTO embeddings_metadata \
261 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
262 VALUES (?, ?, ?, ?, ?) \
263 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
264 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
265 ))
266 .bind(message_id)
267 .bind(chunk_index_i64)
268 .bind(&point_id)
269 .bind(dimensions)
270 .bind(model)
271 .execute(&self.pool)
272 .await?;
273
274 Ok(point_id)
275 }
276
277 #[allow(clippy::too_many_arguments)]
291 pub async fn store_with_category(
292 &self,
293 message_id: MessageId,
294 conversation_id: ConversationId,
295 role: &str,
296 vector: Vec<f32>,
297 kind: MessageKind,
298 model: &str,
299 chunk_index: u32,
300 category: Option<&str>,
301 ) -> Result<String, MemoryError> {
302 let point_id = uuid::Uuid::new_v4().to_string();
303 let dimensions = i64::try_from(vector.len())?;
304
305 let mut payload = std::collections::HashMap::from([
306 ("message_id".to_owned(), serde_json::json!(message_id.0)),
307 (
308 "conversation_id".to_owned(),
309 serde_json::json!(conversation_id.0),
310 ),
311 ("role".to_owned(), serde_json::json!(role)),
312 (
313 "is_summary".to_owned(),
314 serde_json::json!(kind.is_summary()),
315 ),
316 ]);
317 if let Some(cat) = category {
318 payload.insert("category".to_owned(), serde_json::json!(cat));
319 }
320
321 let point = VectorPoint {
322 id: point_id.clone(),
323 vector,
324 payload,
325 };
326
327 self.ops.upsert(&self.collection, vec![point]).await?;
328
329 let chunk_index_i64 = i64::from(chunk_index);
330 zeph_db::query(sql!(
331 "INSERT INTO embeddings_metadata \
332 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
333 VALUES (?, ?, ?, ?, ?) \
334 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
335 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
336 ))
337 .bind(message_id)
338 .bind(chunk_index_i64)
339 .bind(&point_id)
340 .bind(dimensions)
341 .bind(model)
342 .execute(&self.pool)
343 .await?;
344
345 Ok(point_id)
346 }
347
348 pub async fn search(
354 &self,
355 query_vector: &[f32],
356 limit: usize,
357 filter: Option<SearchFilter>,
358 ) -> Result<Vec<SearchResult>, MemoryError> {
359 let limit_u64 = u64::try_from(limit)?;
360
361 let vector_filter = filter.as_ref().and_then(|f| {
362 let mut must = Vec::new();
363 if let Some(cid) = f.conversation_id {
364 must.push(FieldCondition {
365 field: "conversation_id".into(),
366 value: FieldValue::Integer(cid.0),
367 });
368 }
369 if let Some(ref role) = f.role {
370 must.push(FieldCondition {
371 field: "role".into(),
372 value: FieldValue::Text(role.clone()),
373 });
374 }
375 if let Some(ref category) = f.category {
376 must.push(FieldCondition {
377 field: "category".into(),
378 value: FieldValue::Text(category.clone()),
379 });
380 }
381 if must.is_empty() {
382 None
383 } else {
384 Some(VectorFilter {
385 must,
386 must_not: vec![],
387 })
388 }
389 });
390
391 let results = self
392 .ops
393 .search(
394 &self.collection,
395 query_vector.to_vec(),
396 limit_u64,
397 vector_filter,
398 )
399 .await?;
400
401 let mut best: std::collections::HashMap<MessageId, SearchResult> =
404 std::collections::HashMap::new();
405 for point in results {
406 let Some(message_id) = point
407 .payload
408 .get("message_id")
409 .and_then(serde_json::Value::as_i64)
410 else {
411 continue;
412 };
413 let Some(conversation_id) = point
414 .payload
415 .get("conversation_id")
416 .and_then(serde_json::Value::as_i64)
417 else {
418 continue;
419 };
420 let message_id = MessageId(message_id);
421 let entry = best.entry(message_id).or_insert(SearchResult {
422 message_id,
423 conversation_id: ConversationId(conversation_id),
424 score: f32::NEG_INFINITY,
425 });
426 if point.score > entry.score {
427 entry.score = point.score;
428 }
429 }
430
431 let mut search_results: Vec<SearchResult> = best.into_values().collect();
432 search_results.sort_by(|a, b| {
433 b.score
434 .partial_cmp(&a.score)
435 .unwrap_or(std::cmp::Ordering::Equal)
436 });
437 search_results.truncate(limit);
438
439 Ok(search_results)
440 }
441
442 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
448 self.ops.collection_exists(name).await.map_err(Into::into)
449 }
450
451 pub async fn ensure_named_collection(
457 &self,
458 name: &str,
459 vector_size: u64,
460 ) -> Result<(), MemoryError> {
461 self.ops.ensure_collection(name, vector_size).await?;
462 Ok(())
463 }
464
465 pub async fn store_to_collection(
473 &self,
474 collection: &str,
475 payload: serde_json::Value,
476 vector: Vec<f32>,
477 ) -> Result<String, MemoryError> {
478 let point_id = uuid::Uuid::new_v4().to_string();
479 let payload_map: std::collections::HashMap<String, serde_json::Value> =
480 serde_json::from_value(payload)?;
481 let point = VectorPoint {
482 id: point_id.clone(),
483 vector,
484 payload: payload_map,
485 };
486 self.ops.upsert(collection, vec![point]).await?;
487 Ok(point_id)
488 }
489
490 pub async fn upsert_to_collection(
498 &self,
499 collection: &str,
500 point_id: &str,
501 payload: serde_json::Value,
502 vector: Vec<f32>,
503 ) -> Result<(), MemoryError> {
504 let payload_map: std::collections::HashMap<String, serde_json::Value> =
505 serde_json::from_value(payload)?;
506 let point = VectorPoint {
507 id: point_id.to_owned(),
508 vector,
509 payload: payload_map,
510 };
511 self.ops.upsert(collection, vec![point]).await?;
512 Ok(())
513 }
514
515 pub async fn search_collection(
521 &self,
522 collection: &str,
523 query_vector: &[f32],
524 limit: usize,
525 filter: Option<VectorFilter>,
526 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
527 let limit_u64 = u64::try_from(limit)?;
528 let results = self
529 .ops
530 .search(collection, query_vector.to_vec(), limit_u64, filter)
531 .await?;
532 Ok(results)
533 }
534
535 pub async fn get_vectors(
543 &self,
544 ids: &[MessageId],
545 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
546 if ids.is_empty() {
547 return Ok(std::collections::HashMap::new());
548 }
549
550 let placeholders = zeph_db::placeholder_list(1, ids.len());
551 let query = format!(
552 "SELECT em.message_id, vp.vector \
553 FROM embeddings_metadata em \
554 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
555 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
556 );
557 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
558 for &id in ids {
559 q = q.bind(id);
560 }
561
562 let rows = q.fetch_all(&self.pool).await?;
563
564 let map = rows
565 .into_iter()
566 .filter_map(|(msg_id, blob)| {
567 if blob.len() % 4 != 0 {
568 return None;
569 }
570 let vec: Vec<f32> = blob
571 .chunks_exact(4)
572 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
573 .collect();
574 Some((msg_id, vec))
575 })
576 .collect();
577
578 Ok(map)
579 }
580
581 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
587 let row: (i64,) = zeph_db::query_as(sql!(
588 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
589 ))
590 .bind(message_id)
591 .fetch_one(&self.pool)
592 .await?;
593
594 Ok(row.0 > 0)
595 }
596
597 pub async fn is_epoch_current(
607 &self,
608 entity_name: &str,
609 qdrant_epoch: u64,
610 ) -> Result<bool, MemoryError> {
611 let row: Option<(i64,)> = zeph_db::query_as(sql!(
612 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
613 ))
614 .bind(entity_name)
615 .fetch_optional(&self.pool)
616 .await?;
617
618 match row {
619 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
621 }
622 }
623}
624
625#[cfg(test)]
626mod tests {
627 use super::*;
628 use crate::in_memory_store::InMemoryVectorStore;
629 use crate::store::SqliteStore;
630
631 async fn setup() -> (SqliteStore, DbPool) {
632 let store = SqliteStore::new(":memory:").await.unwrap();
633 let pool = store.pool().clone();
634 (store, pool)
635 }
636
637 async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
638 let sqlite = SqliteStore::new(":memory:").await.unwrap();
639 let pool = sqlite.pool().clone();
640 let mem_store = Box::new(InMemoryVectorStore::new());
641 let embedding_store = EmbeddingStore::with_store(mem_store, pool);
642 embedding_store.ensure_collection(4).await.unwrap();
644 (embedding_store, sqlite)
645 }
646
647 #[tokio::test]
648 async fn has_embedding_returns_false_when_none() {
649 let (_store, pool) = setup().await;
650
651 let row: (i64,) = zeph_db::query_as(sql!(
652 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
653 ))
654 .bind(999_i64)
655 .fetch_one(&pool)
656 .await
657 .unwrap();
658
659 assert_eq!(row.0, 0);
660 }
661
662 #[tokio::test]
663 async fn insert_and_query_embeddings_metadata() {
664 let (sqlite, pool) = setup().await;
665 let cid = sqlite.create_conversation().await.unwrap();
666 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
667
668 let point_id = uuid::Uuid::new_v4().to_string();
669 zeph_db::query(sql!(
670 "INSERT INTO embeddings_metadata \
671 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
672 VALUES (?, ?, ?, ?, ?)"
673 ))
674 .bind(msg_id)
675 .bind(0_i64)
676 .bind(&point_id)
677 .bind(768_i64)
678 .bind("qwen3-embedding")
679 .execute(&pool)
680 .await
681 .unwrap();
682
683 let row: (i64,) = zeph_db::query_as(sql!(
684 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
685 ))
686 .bind(msg_id)
687 .fetch_one(&pool)
688 .await
689 .unwrap();
690 assert_eq!(row.0, 1);
691 }
692
693 #[tokio::test]
694 async fn embedding_store_search_empty_returns_empty() {
695 let (store, _sqlite) = setup_with_store().await;
696 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
697 assert!(results.is_empty());
698 }
699
700 #[tokio::test]
701 async fn embedding_store_store_and_search() {
702 let (store, sqlite) = setup_with_store().await;
703 let cid = sqlite.create_conversation().await.unwrap();
704 let msg_id = sqlite
705 .save_message(cid, "user", "test message")
706 .await
707 .unwrap();
708
709 store
710 .store(
711 msg_id,
712 cid,
713 "user",
714 vec![1.0, 0.0, 0.0, 0.0],
715 MessageKind::Regular,
716 "test-model",
717 0,
718 )
719 .await
720 .unwrap();
721
722 let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
723 assert_eq!(results.len(), 1);
724 assert_eq!(results[0].message_id, msg_id);
725 assert_eq!(results[0].conversation_id, cid);
726 assert!((results[0].score - 1.0).abs() < 0.001);
727 }
728
729 #[tokio::test]
730 async fn embedding_store_has_embedding_false_for_unknown() {
731 let (store, sqlite) = setup_with_store().await;
732 let cid = sqlite.create_conversation().await.unwrap();
733 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
734 assert!(!store.has_embedding(msg_id).await.unwrap());
735 }
736
737 #[tokio::test]
738 async fn embedding_store_has_embedding_true_after_store() {
739 let (store, sqlite) = setup_with_store().await;
740 let cid = sqlite.create_conversation().await.unwrap();
741 let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
742
743 store
744 .store(
745 msg_id,
746 cid,
747 "user",
748 vec![0.0, 1.0, 0.0, 0.0],
749 MessageKind::Regular,
750 "test-model",
751 0,
752 )
753 .await
754 .unwrap();
755
756 assert!(store.has_embedding(msg_id).await.unwrap());
757 }
758
759 #[tokio::test]
760 async fn embedding_store_search_with_conversation_filter() {
761 let (store, sqlite) = setup_with_store().await;
762 let cid1 = sqlite.create_conversation().await.unwrap();
763 let cid2 = sqlite.create_conversation().await.unwrap();
764 let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
765 let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
766
767 store
768 .store(
769 msg1,
770 cid1,
771 "user",
772 vec![1.0, 0.0, 0.0, 0.0],
773 MessageKind::Regular,
774 "m",
775 0,
776 )
777 .await
778 .unwrap();
779 store
780 .store(
781 msg2,
782 cid2,
783 "user",
784 vec![1.0, 0.0, 0.0, 0.0],
785 MessageKind::Regular,
786 "m",
787 0,
788 )
789 .await
790 .unwrap();
791
792 let results = store
793 .search(
794 &[1.0, 0.0, 0.0, 0.0],
795 10,
796 Some(SearchFilter {
797 conversation_id: Some(cid1),
798 role: None,
799 category: None,
800 }),
801 )
802 .await
803 .unwrap();
804 assert_eq!(results.len(), 1);
805 assert_eq!(results[0].conversation_id, cid1);
806 }
807
808 #[tokio::test]
809 async fn unique_constraint_on_message_chunk_and_model() {
810 let (sqlite, pool) = setup().await;
811 let cid = sqlite.create_conversation().await.unwrap();
812 let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
813
814 let point_id1 = uuid::Uuid::new_v4().to_string();
815 zeph_db::query(sql!(
816 "INSERT INTO embeddings_metadata \
817 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
818 VALUES (?, ?, ?, ?, ?)"
819 ))
820 .bind(msg_id)
821 .bind(0_i64)
822 .bind(&point_id1)
823 .bind(768_i64)
824 .bind("qwen3-embedding")
825 .execute(&pool)
826 .await
827 .unwrap();
828
829 let point_id2 = uuid::Uuid::new_v4().to_string();
831 let result = zeph_db::query(sql!(
832 "INSERT INTO embeddings_metadata \
833 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
834 VALUES (?, ?, ?, ?, ?)"
835 ))
836 .bind(msg_id)
837 .bind(0_i64)
838 .bind(&point_id2)
839 .bind(768_i64)
840 .bind("qwen3-embedding")
841 .execute(&pool)
842 .await;
843 assert!(result.is_err());
844
845 let point_id3 = uuid::Uuid::new_v4().to_string();
847 zeph_db::query(sql!(
848 "INSERT INTO embeddings_metadata \
849 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
850 VALUES (?, ?, ?, ?, ?)"
851 ))
852 .bind(msg_id)
853 .bind(1_i64)
854 .bind(&point_id3)
855 .bind(768_i64)
856 .bind("qwen3-embedding")
857 .execute(&pool)
858 .await
859 .unwrap();
860 }
861}