1pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Distinguishes ordinary conversation messages from summary messages.
///
/// Stored in vector payloads as the boolean `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// A regular conversation message.
    Regular,
    /// A message that summarizes other messages.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Summary => true,
            Self::Regular => false,
        }
    }
}
28
/// Collection name every `EmbeddingStore` constructor uses for message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
30
/// Asks `ops` to ensure `collection` exists with `vector_size`-dimensional vectors.
///
/// Free-function convenience wrapper that delegates directly to
/// `QdrantOps::ensure_collection` and returns its result unchanged.
///
/// # Errors
///
/// Propagates the boxed Qdrant client error returned by `ops`.
pub async fn ensure_qdrant_collection(
    ops: &QdrantOps,
    collection: &str,
    vector_size: u64,
) -> Result<(), Box<qdrant_client::QdrantError>> {
    ops.ensure_collection(collection, vector_size).await
}
45
/// Writes message embeddings to a vector backend and records each point in the
/// `embeddings_metadata` table.
pub struct EmbeddingStore {
    /// Vector backend: Qdrant, the database-backed store, or any custom `VectorStore`.
    ops: Box<dyn VectorStore>,
    /// Collection used by the message-level `store`/`search` helpers.
    collection: String,
    /// Database pool used for embedding metadata and related lookups.
    pool: DbPool,
}
51
52impl std::fmt::Debug for EmbeddingStore {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("EmbeddingStore")
55 .field("collection", &self.collection)
56 .finish_non_exhaustive()
57 }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62 pub conversation_id: Option<ConversationId>,
63 pub role: Option<String>,
64}
65
66#[derive(Debug)]
67pub struct SearchResult {
68 pub message_id: MessageId,
69 pub conversation_id: ConversationId,
70 pub score: f32,
71}
72
73impl EmbeddingStore {
74 pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
83 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
84
85 Ok(Self {
86 ops: Box::new(ops),
87 collection: COLLECTION_NAME.into(),
88 pool,
89 })
90 }
91
92 #[must_use]
96 pub fn new_sqlite(pool: DbPool) -> Self {
97 let ops = DbVectorStore::new(pool.clone());
98 Self {
99 ops: Box::new(ops),
100 collection: COLLECTION_NAME.into(),
101 pool,
102 }
103 }
104
105 #[must_use]
106 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
107 Self {
108 ops: store,
109 collection: COLLECTION_NAME.into(),
110 pool,
111 }
112 }
113
114 pub async fn health_check(&self) -> bool {
115 self.ops.health_check().await.unwrap_or(false)
116 }
117
118 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
126 self.ops
127 .ensure_collection(&self.collection, vector_size)
128 .await?;
129 Ok(())
130 }
131
132 #[allow(clippy::too_many_arguments)]
142 pub async fn store_with_tool_context(
143 &self,
144 message_id: MessageId,
145 conversation_id: ConversationId,
146 role: &str,
147 vector: Vec<f32>,
148 kind: MessageKind,
149 model: &str,
150 chunk_index: u32,
151 tool_name: &str,
152 exit_code: Option<i32>,
153 timestamp: Option<&str>,
154 ) -> Result<String, MemoryError> {
155 let point_id = uuid::Uuid::new_v4().to_string();
156 let dimensions = i64::try_from(vector.len())?;
157
158 let mut payload = std::collections::HashMap::from([
159 ("message_id".to_owned(), serde_json::json!(message_id.0)),
160 (
161 "conversation_id".to_owned(),
162 serde_json::json!(conversation_id.0),
163 ),
164 ("role".to_owned(), serde_json::json!(role)),
165 (
166 "is_summary".to_owned(),
167 serde_json::json!(kind.is_summary()),
168 ),
169 ("tool_name".to_owned(), serde_json::json!(tool_name)),
170 ]);
171 if let Some(code) = exit_code {
172 payload.insert("exit_code".to_owned(), serde_json::json!(code));
173 }
174 if let Some(ts) = timestamp {
175 payload.insert("timestamp".to_owned(), serde_json::json!(ts));
176 }
177
178 let point = VectorPoint {
179 id: point_id.clone(),
180 vector,
181 payload,
182 };
183
184 self.ops.upsert(&self.collection, vec![point]).await?;
185
186 let chunk_index_i64 = i64::from(chunk_index);
187 zeph_db::query(sql!(
188 "INSERT INTO embeddings_metadata \
189 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
190 VALUES (?, ?, ?, ?, ?) \
191 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
192 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
193 ))
194 .bind(message_id)
195 .bind(chunk_index_i64)
196 .bind(&point_id)
197 .bind(dimensions)
198 .bind(model)
199 .execute(&self.pool)
200 .await?;
201
202 Ok(point_id)
203 }
204
205 #[allow(clippy::too_many_arguments)]
216 pub async fn store(
217 &self,
218 message_id: MessageId,
219 conversation_id: ConversationId,
220 role: &str,
221 vector: Vec<f32>,
222 kind: MessageKind,
223 model: &str,
224 chunk_index: u32,
225 ) -> Result<String, MemoryError> {
226 let point_id = uuid::Uuid::new_v4().to_string();
227 let dimensions = i64::try_from(vector.len())?;
228
229 let payload = std::collections::HashMap::from([
230 ("message_id".to_owned(), serde_json::json!(message_id.0)),
231 (
232 "conversation_id".to_owned(),
233 serde_json::json!(conversation_id.0),
234 ),
235 ("role".to_owned(), serde_json::json!(role)),
236 (
237 "is_summary".to_owned(),
238 serde_json::json!(kind.is_summary()),
239 ),
240 ]);
241
242 let point = VectorPoint {
243 id: point_id.clone(),
244 vector,
245 payload,
246 };
247
248 self.ops.upsert(&self.collection, vec![point]).await?;
249
250 let chunk_index_i64 = i64::from(chunk_index);
251 zeph_db::query(sql!(
252 "INSERT INTO embeddings_metadata \
253 (message_id, chunk_index, qdrant_point_id, dimensions, model) \
254 VALUES (?, ?, ?, ?, ?) \
255 ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
256 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
257 ))
258 .bind(message_id)
259 .bind(chunk_index_i64)
260 .bind(&point_id)
261 .bind(dimensions)
262 .bind(model)
263 .execute(&self.pool)
264 .await?;
265
266 Ok(point_id)
267 }
268
269 pub async fn search(
275 &self,
276 query_vector: &[f32],
277 limit: usize,
278 filter: Option<SearchFilter>,
279 ) -> Result<Vec<SearchResult>, MemoryError> {
280 let limit_u64 = u64::try_from(limit)?;
281
282 let vector_filter = filter.as_ref().and_then(|f| {
283 let mut must = Vec::new();
284 if let Some(cid) = f.conversation_id {
285 must.push(FieldCondition {
286 field: "conversation_id".into(),
287 value: FieldValue::Integer(cid.0),
288 });
289 }
290 if let Some(ref role) = f.role {
291 must.push(FieldCondition {
292 field: "role".into(),
293 value: FieldValue::Text(role.clone()),
294 });
295 }
296 if must.is_empty() {
297 None
298 } else {
299 Some(VectorFilter {
300 must,
301 must_not: vec![],
302 })
303 }
304 });
305
306 let results = self
307 .ops
308 .search(
309 &self.collection,
310 query_vector.to_vec(),
311 limit_u64,
312 vector_filter,
313 )
314 .await?;
315
316 let mut best: std::collections::HashMap<MessageId, SearchResult> =
319 std::collections::HashMap::new();
320 for point in results {
321 let Some(message_id) = point
322 .payload
323 .get("message_id")
324 .and_then(serde_json::Value::as_i64)
325 else {
326 continue;
327 };
328 let Some(conversation_id) = point
329 .payload
330 .get("conversation_id")
331 .and_then(serde_json::Value::as_i64)
332 else {
333 continue;
334 };
335 let message_id = MessageId(message_id);
336 let entry = best.entry(message_id).or_insert(SearchResult {
337 message_id,
338 conversation_id: ConversationId(conversation_id),
339 score: f32::NEG_INFINITY,
340 });
341 if point.score > entry.score {
342 entry.score = point.score;
343 }
344 }
345
346 let mut search_results: Vec<SearchResult> = best.into_values().collect();
347 search_results.sort_by(|a, b| {
348 b.score
349 .partial_cmp(&a.score)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 });
352 search_results.truncate(limit);
353
354 Ok(search_results)
355 }
356
357 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
363 self.ops.collection_exists(name).await.map_err(Into::into)
364 }
365
366 pub async fn ensure_named_collection(
372 &self,
373 name: &str,
374 vector_size: u64,
375 ) -> Result<(), MemoryError> {
376 self.ops.ensure_collection(name, vector_size).await?;
377 Ok(())
378 }
379
380 pub async fn store_to_collection(
388 &self,
389 collection: &str,
390 payload: serde_json::Value,
391 vector: Vec<f32>,
392 ) -> Result<String, MemoryError> {
393 let point_id = uuid::Uuid::new_v4().to_string();
394 let payload_map: std::collections::HashMap<String, serde_json::Value> =
395 serde_json::from_value(payload)?;
396 let point = VectorPoint {
397 id: point_id.clone(),
398 vector,
399 payload: payload_map,
400 };
401 self.ops.upsert(collection, vec![point]).await?;
402 Ok(point_id)
403 }
404
405 pub async fn upsert_to_collection(
413 &self,
414 collection: &str,
415 point_id: &str,
416 payload: serde_json::Value,
417 vector: Vec<f32>,
418 ) -> Result<(), MemoryError> {
419 let payload_map: std::collections::HashMap<String, serde_json::Value> =
420 serde_json::from_value(payload)?;
421 let point = VectorPoint {
422 id: point_id.to_owned(),
423 vector,
424 payload: payload_map,
425 };
426 self.ops.upsert(collection, vec![point]).await?;
427 Ok(())
428 }
429
430 pub async fn search_collection(
436 &self,
437 collection: &str,
438 query_vector: &[f32],
439 limit: usize,
440 filter: Option<VectorFilter>,
441 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
442 let limit_u64 = u64::try_from(limit)?;
443 let results = self
444 .ops
445 .search(collection, query_vector.to_vec(), limit_u64, filter)
446 .await?;
447 Ok(results)
448 }
449
450 pub async fn get_vectors(
458 &self,
459 ids: &[MessageId],
460 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
461 if ids.is_empty() {
462 return Ok(std::collections::HashMap::new());
463 }
464
465 let placeholders = zeph_db::placeholder_list(1, ids.len());
466 let query = format!(
467 "SELECT em.message_id, vp.vector \
468 FROM embeddings_metadata em \
469 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
470 WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
471 );
472 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
473 for &id in ids {
474 q = q.bind(id);
475 }
476
477 let rows = q.fetch_all(&self.pool).await?;
478
479 let map = rows
480 .into_iter()
481 .filter_map(|(msg_id, blob)| {
482 if blob.len() % 4 != 0 {
483 return None;
484 }
485 let vec: Vec<f32> = blob
486 .chunks_exact(4)
487 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
488 .collect();
489 Some((msg_id, vec))
490 })
491 .collect();
492
493 Ok(map)
494 }
495
496 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
502 let row: (i64,) = zeph_db::query_as(sql!(
503 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
504 ))
505 .bind(message_id)
506 .fetch_one(&self.pool)
507 .await?;
508
509 Ok(row.0 > 0)
510 }
511
512 pub async fn is_epoch_current(
522 &self,
523 entity_name: &str,
524 qdrant_epoch: u64,
525 ) -> Result<bool, MemoryError> {
526 let row: Option<(i64,)> = zeph_db::query_as(sql!(
527 "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
528 ))
529 .bind(entity_name)
530 .fetch_optional(&self.pool)
531 .await?;
532
533 match row {
534 None => Ok(true), Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
536 }
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// Opens an in-memory SQLite store and returns it with a clone of its pool.
    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// Builds an `EmbeddingStore` over an in-memory vector backend with a
    /// 4-dimensional default collection, sharing a fresh SQLite store's pool.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// The metadata table exists and is empty for an unknown message id.
    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (_store, pool) = setup().await;

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(999_i64)
        .fetch_one(&pool)
        .await
        .unwrap();

        assert_eq!(row.0, 0);
    }

    /// A raw metadata insert is visible to a count query.
    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);
    }

    /// Searching an empty collection yields no results.
    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    /// A stored vector is found again with a near-perfect score.
    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    /// A saved message without an embedding reports `false`.
    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    /// `store` records a metadata row, so `has_embedding` becomes `true`.
    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    /// A conversation filter excludes hits from other conversations.
    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
                0,
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    /// Storing the same (message, chunk, model) twice leaves exactly one metadata row
    /// and one message-level search hit.
    #[tokio::test]
    async fn storing_same_chunk_twice_keeps_single_metadata_row() {
        let (store, sqlite) = setup_with_store().await;
        let pool = sqlite.pool().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        for vector in [vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]] {
            store
                .store(
                    msg_id,
                    cid,
                    "user",
                    vector,
                    MessageKind::Regular,
                    "test-model",
                    0,
                )
                .await
                .unwrap();
        }

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);

        let results = store.search(&[0.0, 1.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    /// `store_with_tool_context` writes metadata and a searchable point.
    #[tokio::test]
    async fn store_with_tool_context_is_searchable() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "ran a tool").await.unwrap();

        store
            .store_with_tool_context(
                msg_id,
                cid,
                "user",
                vec![0.0, 0.0, 1.0, 0.0],
                MessageKind::Regular,
                "test-model",
                0,
                "shell",
                Some(0),
                Some("2024-01-01T00:00:00Z"),
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
        let results = store.search(&[0.0, 0.0, 1.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
    }

    /// An empty id list short-circuits to an empty map.
    #[tokio::test]
    async fn get_vectors_with_no_ids_returns_empty_map() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(store.get_vectors(&[]).await.unwrap().is_empty());
    }

    /// `(message_id, chunk_index, model)` is unique: a duplicate is rejected, while a
    /// different chunk or a different model is accepted.
    #[tokio::test]
    async fn unique_constraint_on_message_chunk_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // Same (message, chunk, model) must violate the unique constraint.
        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;
        assert!(result.is_err());

        // A different chunk of the same message is a distinct key.
        let point_id3 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(1_i64)
        .bind(&point_id3)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        // A different model for the same (message, chunk) is also a distinct key.
        let point_id4 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata \
             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(0_i64)
        .bind(&point_id4)
        .bind(1024_i64)
        .bind("other-embedding")
        .execute(&pool)
        .await
        .unwrap();
    }
}
775}