1pub use qdrant_client::qdrant::Filter;
5use zeph_db::DbPool;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::db_vector_store::DbVectorStore;
10use crate::error::MemoryError;
11use crate::qdrant_ops::QdrantOps;
12use crate::types::{ConversationId, MessageId};
13use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
14
/// Kind of message an embedding is stored for.
///
/// Recorded in the vector payload as the boolean `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
28
29const COLLECTION_NAME: &str = "zeph_conversations";
30
31pub async fn ensure_qdrant_collection(
39 ops: &QdrantOps,
40 collection: &str,
41 vector_size: u64,
42) -> Result<(), Box<qdrant_client::QdrantError>> {
43 ops.ensure_collection(collection, vector_size).await
44}
45
46pub struct EmbeddingStore {
47 ops: Box<dyn VectorStore>,
48 collection: String,
49 pool: DbPool,
50}
51
52impl std::fmt::Debug for EmbeddingStore {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("EmbeddingStore")
55 .field("collection", &self.collection)
56 .finish_non_exhaustive()
57 }
58}
59
60#[derive(Debug)]
61pub struct SearchFilter {
62 pub conversation_id: Option<ConversationId>,
63 pub role: Option<String>,
64}
65
66#[derive(Debug)]
67pub struct SearchResult {
68 pub message_id: MessageId,
69 pub conversation_id: ConversationId,
70 pub score: f32,
71}
72
73impl EmbeddingStore {
74 pub fn new(url: &str, pool: DbPool) -> Result<Self, MemoryError> {
83 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
84
85 Ok(Self {
86 ops: Box::new(ops),
87 collection: COLLECTION_NAME.into(),
88 pool,
89 })
90 }
91
92 #[must_use]
96 pub fn new_sqlite(pool: DbPool) -> Self {
97 let ops = DbVectorStore::new(pool.clone());
98 Self {
99 ops: Box::new(ops),
100 collection: COLLECTION_NAME.into(),
101 pool,
102 }
103 }
104
105 #[must_use]
106 pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
107 Self {
108 ops: store,
109 collection: COLLECTION_NAME.into(),
110 pool,
111 }
112 }
113
114 pub async fn health_check(&self) -> bool {
115 self.ops.health_check().await.unwrap_or(false)
116 }
117
118 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
126 self.ops
127 .ensure_collection(&self.collection, vector_size)
128 .await?;
129 Ok(())
130 }
131
132 pub async fn store(
140 &self,
141 message_id: MessageId,
142 conversation_id: ConversationId,
143 role: &str,
144 vector: Vec<f32>,
145 kind: MessageKind,
146 model: &str,
147 ) -> Result<String, MemoryError> {
148 let point_id = uuid::Uuid::new_v4().to_string();
149 let dimensions = i64::try_from(vector.len())?;
150
151 let payload = std::collections::HashMap::from([
152 ("message_id".to_owned(), serde_json::json!(message_id.0)),
153 (
154 "conversation_id".to_owned(),
155 serde_json::json!(conversation_id.0),
156 ),
157 ("role".to_owned(), serde_json::json!(role)),
158 (
159 "is_summary".to_owned(),
160 serde_json::json!(kind.is_summary()),
161 ),
162 ]);
163
164 let point = VectorPoint {
165 id: point_id.clone(),
166 vector,
167 payload,
168 };
169
170 self.ops.upsert(&self.collection, vec![point]).await?;
171
172 zeph_db::query(sql!(
173 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
174 VALUES (?, ?, ?, ?) \
175 ON CONFLICT(message_id, model) DO UPDATE SET \
176 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
177 ))
178 .bind(message_id)
179 .bind(&point_id)
180 .bind(dimensions)
181 .bind(model)
182 .execute(&self.pool)
183 .await?;
184
185 Ok(point_id)
186 }
187
188 pub async fn search(
194 &self,
195 query_vector: &[f32],
196 limit: usize,
197 filter: Option<SearchFilter>,
198 ) -> Result<Vec<SearchResult>, MemoryError> {
199 let limit_u64 = u64::try_from(limit)?;
200
201 let vector_filter = filter.as_ref().and_then(|f| {
202 let mut must = Vec::new();
203 if let Some(cid) = f.conversation_id {
204 must.push(FieldCondition {
205 field: "conversation_id".into(),
206 value: FieldValue::Integer(cid.0),
207 });
208 }
209 if let Some(ref role) = f.role {
210 must.push(FieldCondition {
211 field: "role".into(),
212 value: FieldValue::Text(role.clone()),
213 });
214 }
215 if must.is_empty() {
216 None
217 } else {
218 Some(VectorFilter {
219 must,
220 must_not: vec![],
221 })
222 }
223 });
224
225 let results = self
226 .ops
227 .search(
228 &self.collection,
229 query_vector.to_vec(),
230 limit_u64,
231 vector_filter,
232 )
233 .await?;
234
235 let search_results = results
236 .into_iter()
237 .filter_map(|point| {
238 let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
239 let conversation_id =
240 ConversationId(point.payload.get("conversation_id")?.as_i64()?);
241 Some(SearchResult {
242 message_id,
243 conversation_id,
244 score: point.score,
245 })
246 })
247 .collect();
248
249 Ok(search_results)
250 }
251
252 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
258 self.ops.collection_exists(name).await.map_err(Into::into)
259 }
260
261 pub async fn ensure_named_collection(
267 &self,
268 name: &str,
269 vector_size: u64,
270 ) -> Result<(), MemoryError> {
271 self.ops.ensure_collection(name, vector_size).await?;
272 Ok(())
273 }
274
275 pub async fn store_to_collection(
283 &self,
284 collection: &str,
285 payload: serde_json::Value,
286 vector: Vec<f32>,
287 ) -> Result<String, MemoryError> {
288 let point_id = uuid::Uuid::new_v4().to_string();
289 let payload_map: std::collections::HashMap<String, serde_json::Value> =
290 serde_json::from_value(payload)?;
291 let point = VectorPoint {
292 id: point_id.clone(),
293 vector,
294 payload: payload_map,
295 };
296 self.ops.upsert(collection, vec![point]).await?;
297 Ok(point_id)
298 }
299
300 pub async fn upsert_to_collection(
308 &self,
309 collection: &str,
310 point_id: &str,
311 payload: serde_json::Value,
312 vector: Vec<f32>,
313 ) -> Result<(), MemoryError> {
314 let payload_map: std::collections::HashMap<String, serde_json::Value> =
315 serde_json::from_value(payload)?;
316 let point = VectorPoint {
317 id: point_id.to_owned(),
318 vector,
319 payload: payload_map,
320 };
321 self.ops.upsert(collection, vec![point]).await?;
322 Ok(())
323 }
324
325 pub async fn search_collection(
331 &self,
332 collection: &str,
333 query_vector: &[f32],
334 limit: usize,
335 filter: Option<VectorFilter>,
336 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
337 let limit_u64 = u64::try_from(limit)?;
338 let results = self
339 .ops
340 .search(collection, query_vector.to_vec(), limit_u64, filter)
341 .await?;
342 Ok(results)
343 }
344
345 pub async fn get_vectors(
353 &self,
354 ids: &[MessageId],
355 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
356 if ids.is_empty() {
357 return Ok(std::collections::HashMap::new());
358 }
359
360 let placeholders = zeph_db::placeholder_list(1, ids.len());
361 let query = format!(
362 "SELECT em.message_id, vp.vector \
363 FROM embeddings_metadata em \
364 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
365 WHERE em.message_id IN ({placeholders})"
366 );
367 let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
368 for &id in ids {
369 q = q.bind(id);
370 }
371
372 let rows = q.fetch_all(&self.pool).await?;
373
374 let map = rows
375 .into_iter()
376 .filter_map(|(msg_id, blob)| {
377 if blob.len() % 4 != 0 {
378 return None;
379 }
380 let vec: Vec<f32> = blob
381 .chunks_exact(4)
382 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
383 .collect();
384 Some((msg_id, vec))
385 })
386 .collect();
387
388 Ok(map)
389 }
390
391 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
397 let row: (i64,) = zeph_db::query_as(sql!(
398 "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
399 ))
400 .bind(message_id)
401 .fetch_one(&self.pool)
402 .await?;
403
404 Ok(row.0 > 0)
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    /// In-memory SQLite store plus a clone of its pool, for raw metadata queries.
    async fn setup() -> (SqliteStore, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an in-memory vector backend with a 4-dim default collection.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // Exercise the method itself rather than re-running its SQL by hand.
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let row: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
        ))
        .bind(msg_id)
        .fetch_one(&pool)
        .await
        .unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        store
            .store(
                msg1,
                cid1,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();
        store
            .store(
                msg2,
                cid2,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "m",
            )
            .await
            .unwrap();

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn embedding_store_search_with_empty_filter_is_unfiltered() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "a").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "b").await.unwrap();

        for (msg, cid) in [(msg1, cid1), (msg2, cid2)] {
            store
                .store(
                    msg,
                    cid,
                    "user",
                    vec![1.0, 0.0, 0.0, 0.0],
                    MessageKind::Regular,
                    "m",
                )
                .await
                .unwrap();
        }

        // A filter with every field unset must not restrict anything.
        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: None,
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn get_vectors_with_no_ids_returns_empty_map() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(store.get_vectors(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn store_to_collection_rejects_non_object_payload() {
        let (store, _sqlite) = setup_with_store().await;
        let result = store
            .store_to_collection(
                COLLECTION_NAME,
                serde_json::json!([1, 2, 3]),
                vec![0.0, 0.0, 1.0, 0.0],
            )
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn upsert_to_collection_then_search_collection_round_trips_payload() {
        let (store, _sqlite) = setup_with_store().await;
        let point_id = uuid::Uuid::new_v4().to_string();
        store
            .upsert_to_collection(
                COLLECTION_NAME,
                &point_id,
                serde_json::json!({ "k": "v" }),
                vec![0.0, 0.0, 1.0, 0.0],
            )
            .await
            .unwrap();

        let hits = store
            .search_collection(COLLECTION_NAME, &[0.0, 0.0, 1.0, 0.0], 5, None)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].payload.get("k"), Some(&serde_json::json!("v")));
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id1 = uuid::Uuid::new_v4().to_string();
        zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(&point_id1)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await
        .unwrap();

        let point_id2 = uuid::Uuid::new_v4().to_string();
        let result = zeph_db::query(sql!(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)"
        ))
        .bind(msg_id)
        .bind(&point_id2)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(&pool)
        .await;

        assert!(result.is_err());
    }
}
617}