1pub use qdrant_client::qdrant::Filter;
5use sqlx::SqlitePool;
6
7use crate::error::MemoryError;
8use crate::qdrant_ops::QdrantOps;
9use crate::sqlite_vector_store::SqliteVectorStore;
10use crate::types::{ConversationId, MessageId};
11use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
12
/// Classifies a stored message as either an ordinary message or a summary.
///
/// The classification is written into the vector payload as the boolean
/// `is_summary` field when an embedding is stored.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
26
/// Name of the collection that `EmbeddingStore` uses for conversation-message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
28
29pub async fn ensure_qdrant_collection(
37 ops: &QdrantOps,
38 collection: &str,
39 vector_size: u64,
40) -> Result<(), Box<qdrant_client::QdrantError>> {
41 ops.ensure_collection(collection, vector_size).await
42}
43
44pub struct EmbeddingStore {
45 ops: Box<dyn VectorStore>,
46 collection: String,
47 pool: SqlitePool,
48}
49
50impl std::fmt::Debug for EmbeddingStore {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("EmbeddingStore")
53 .field("collection", &self.collection)
54 .finish_non_exhaustive()
55 }
56}
57
58#[derive(Debug)]
59pub struct SearchFilter {
60 pub conversation_id: Option<ConversationId>,
61 pub role: Option<String>,
62}
63
64#[derive(Debug)]
65pub struct SearchResult {
66 pub message_id: MessageId,
67 pub conversation_id: ConversationId,
68 pub score: f32,
69}
70
71impl EmbeddingStore {
72 pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
81 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
82
83 Ok(Self {
84 ops: Box::new(ops),
85 collection: COLLECTION_NAME.into(),
86 pool,
87 })
88 }
89
90 #[must_use]
94 pub fn new_sqlite(pool: SqlitePool) -> Self {
95 let ops = SqliteVectorStore::new(pool.clone());
96 Self {
97 ops: Box::new(ops),
98 collection: COLLECTION_NAME.into(),
99 pool,
100 }
101 }
102
103 #[must_use]
104 pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
105 Self {
106 ops: store,
107 collection: COLLECTION_NAME.into(),
108 pool,
109 }
110 }
111
112 pub async fn health_check(&self) -> bool {
113 self.ops.health_check().await.unwrap_or(false)
114 }
115
116 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
124 self.ops
125 .ensure_collection(&self.collection, vector_size)
126 .await?;
127 Ok(())
128 }
129
130 pub async fn store(
138 &self,
139 message_id: MessageId,
140 conversation_id: ConversationId,
141 role: &str,
142 vector: Vec<f32>,
143 kind: MessageKind,
144 model: &str,
145 ) -> Result<String, MemoryError> {
146 let point_id = uuid::Uuid::new_v4().to_string();
147 let dimensions = i64::try_from(vector.len())?;
148
149 let payload = std::collections::HashMap::from([
150 ("message_id".to_owned(), serde_json::json!(message_id.0)),
151 (
152 "conversation_id".to_owned(),
153 serde_json::json!(conversation_id.0),
154 ),
155 ("role".to_owned(), serde_json::json!(role)),
156 (
157 "is_summary".to_owned(),
158 serde_json::json!(kind.is_summary()),
159 ),
160 ]);
161
162 let point = VectorPoint {
163 id: point_id.clone(),
164 vector,
165 payload,
166 };
167
168 self.ops.upsert(&self.collection, vec![point]).await?;
169
170 sqlx::query(
171 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
172 VALUES (?, ?, ?, ?) \
173 ON CONFLICT(message_id, model) DO UPDATE SET \
174 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
175 )
176 .bind(message_id)
177 .bind(&point_id)
178 .bind(dimensions)
179 .bind(model)
180 .execute(&self.pool)
181 .await?;
182
183 Ok(point_id)
184 }
185
186 pub async fn search(
192 &self,
193 query_vector: &[f32],
194 limit: usize,
195 filter: Option<SearchFilter>,
196 ) -> Result<Vec<SearchResult>, MemoryError> {
197 let limit_u64 = u64::try_from(limit)?;
198
199 let vector_filter = filter.as_ref().and_then(|f| {
200 let mut must = Vec::new();
201 if let Some(cid) = f.conversation_id {
202 must.push(FieldCondition {
203 field: "conversation_id".into(),
204 value: FieldValue::Integer(cid.0),
205 });
206 }
207 if let Some(ref role) = f.role {
208 must.push(FieldCondition {
209 field: "role".into(),
210 value: FieldValue::Text(role.clone()),
211 });
212 }
213 if must.is_empty() {
214 None
215 } else {
216 Some(VectorFilter {
217 must,
218 must_not: vec![],
219 })
220 }
221 });
222
223 let results = self
224 .ops
225 .search(
226 &self.collection,
227 query_vector.to_vec(),
228 limit_u64,
229 vector_filter,
230 )
231 .await?;
232
233 let search_results = results
234 .into_iter()
235 .filter_map(|point| {
236 let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
237 let conversation_id =
238 ConversationId(point.payload.get("conversation_id")?.as_i64()?);
239 Some(SearchResult {
240 message_id,
241 conversation_id,
242 score: point.score,
243 })
244 })
245 .collect();
246
247 Ok(search_results)
248 }
249
250 pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
256 self.ops.collection_exists(name).await.map_err(Into::into)
257 }
258
259 pub async fn ensure_named_collection(
265 &self,
266 name: &str,
267 vector_size: u64,
268 ) -> Result<(), MemoryError> {
269 self.ops.ensure_collection(name, vector_size).await?;
270 Ok(())
271 }
272
273 pub async fn store_to_collection(
281 &self,
282 collection: &str,
283 payload: serde_json::Value,
284 vector: Vec<f32>,
285 ) -> Result<String, MemoryError> {
286 let point_id = uuid::Uuid::new_v4().to_string();
287 let payload_map: std::collections::HashMap<String, serde_json::Value> =
288 serde_json::from_value(payload)?;
289 let point = VectorPoint {
290 id: point_id.clone(),
291 vector,
292 payload: payload_map,
293 };
294 self.ops.upsert(collection, vec![point]).await?;
295 Ok(point_id)
296 }
297
298 pub async fn upsert_to_collection(
306 &self,
307 collection: &str,
308 point_id: &str,
309 payload: serde_json::Value,
310 vector: Vec<f32>,
311 ) -> Result<(), MemoryError> {
312 let payload_map: std::collections::HashMap<String, serde_json::Value> =
313 serde_json::from_value(payload)?;
314 let point = VectorPoint {
315 id: point_id.to_owned(),
316 vector,
317 payload: payload_map,
318 };
319 self.ops.upsert(collection, vec![point]).await?;
320 Ok(())
321 }
322
323 pub async fn search_collection(
329 &self,
330 collection: &str,
331 query_vector: &[f32],
332 limit: usize,
333 filter: Option<VectorFilter>,
334 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
335 let limit_u64 = u64::try_from(limit)?;
336 let results = self
337 .ops
338 .search(collection, query_vector.to_vec(), limit_u64, filter)
339 .await?;
340 Ok(results)
341 }
342
343 pub async fn get_vectors(
351 &self,
352 ids: &[MessageId],
353 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
354 if ids.is_empty() {
355 return Ok(std::collections::HashMap::new());
356 }
357
358 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
359 let query = format!(
360 "SELECT em.message_id, vp.vector \
361 FROM embeddings_metadata em \
362 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
363 WHERE em.message_id IN ({placeholders})"
364 );
365 let mut q = sqlx::query_as::<_, (MessageId, Vec<u8>)>(&query);
366 for &id in ids {
367 q = q.bind(id);
368 }
369
370 let rows = q.fetch_all(&self.pool).await.unwrap_or_default();
371
372 let map = rows
373 .into_iter()
374 .filter_map(|(msg_id, blob)| {
375 if blob.len() % 4 != 0 {
376 return None;
377 }
378 let vec: Vec<f32> = blob
379 .chunks_exact(4)
380 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
381 .collect();
382 Some((msg_id, vec))
383 })
384 .collect();
385
386 Ok(map)
387 }
388
389 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
395 let row: (i64,) =
396 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
397 .bind(message_id)
398 .fetch_one(&self.pool)
399 .await?;
400
401 Ok(row.0 > 0)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    /// Opens a fresh in-memory `SqliteStore` and returns it with a clone of its pool.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// Builds an `EmbeddingStore` over an in-memory vector backend with a 4-dim collection.
    ///
    /// The returned `SqliteStore` shares the metadata pool and is used to create
    /// conversations and messages.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        embedding_store.ensure_collection(4).await.unwrap();

        (embedding_store, sqlite)
    }

    /// Inserts one `embeddings_metadata` row with raw SQL, bypassing `EmbeddingStore::store`.
    async fn insert_metadata(
        pool: &SqlitePool,
        message_id: MessageId,
        point_id: &str,
    ) -> Result<(), sqlx::Error> {
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(message_id)
        .bind(point_id)
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(pool)
        .await
        .map(|_| ())
    }

    /// Counts `embeddings_metadata` rows for `message_id` with raw SQL.
    async fn count_metadata(pool: &SqlitePool, message_id: MessageId) -> i64 {
        let row: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(message_id)
                .fetch_one(pool)
                .await
                .unwrap();
        row.0
    }

    #[test]
    fn message_kind_is_summary() {
        assert!(MessageKind::Summary.is_summary());
        assert!(!MessageKind::Regular.is_summary());
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        insert_metadata(&pool, msg_id, &point_id).await.unwrap();

        assert_eq!(count_metadata(&pool, msg_id).await, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        for (msg, cid) in [(msg1, cid1), (msg2, cid2)] {
            store
                .store(
                    msg,
                    cid,
                    "user",
                    vec![1.0, 0.0, 0.0, 0.0],
                    MessageKind::Regular,
                    "m",
                )
                .await
                .unwrap();
        }

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn get_vectors_empty_ids_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        assert!(store.get_vectors(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn store_to_collection_rejects_non_object_payload() {
        let (store, _sqlite) = setup_with_store().await;
        let result = store
            .store_to_collection("any", serde_json::json!(42), vec![1.0, 0.0, 0.0, 0.0])
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        let first = uuid::Uuid::new_v4().to_string();
        insert_metadata(&pool, msg_id, &first).await.unwrap();

        // Same (message_id, model) pair with a different point id must be rejected.
        let second = uuid::Uuid::new_v4().to_string();
        assert!(insert_metadata(&pool, msg_id, &second).await.is_err());
    }
}