1pub use qdrant_client::qdrant::Filter;
2use sqlx::SqlitePool;
3
4use crate::error::MemoryError;
5use crate::qdrant_ops::QdrantOps;
6use crate::sqlite_vector_store::SqliteVectorStore;
7use crate::types::{ConversationId, MessageId};
8use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
9
/// Distinguishes ordinary conversation messages from summary messages when
/// their embeddings are stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A summary message.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for
    /// [`MessageKind::Regular`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
23
24const COLLECTION_NAME: &str = "zeph_conversations";
25
26pub async fn ensure_qdrant_collection(
34 ops: &QdrantOps,
35 collection: &str,
36 vector_size: u64,
37) -> Result<(), Box<qdrant_client::QdrantError>> {
38 ops.ensure_collection(collection, vector_size).await
39}
40
41pub struct EmbeddingStore {
42 ops: Box<dyn VectorStore>,
43 collection: String,
44 pool: SqlitePool,
45}
46
47impl std::fmt::Debug for EmbeddingStore {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("EmbeddingStore")
50 .field("collection", &self.collection)
51 .finish_non_exhaustive()
52 }
53}
54
55#[derive(Debug)]
56pub struct SearchFilter {
57 pub conversation_id: Option<ConversationId>,
58 pub role: Option<String>,
59}
60
61#[derive(Debug)]
62pub struct SearchResult {
63 pub message_id: MessageId,
64 pub conversation_id: ConversationId,
65 pub score: f32,
66}
67
68impl EmbeddingStore {
69 pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
78 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
79
80 Ok(Self {
81 ops: Box::new(ops),
82 collection: COLLECTION_NAME.into(),
83 pool,
84 })
85 }
86
87 #[must_use]
91 pub fn new_sqlite(pool: SqlitePool) -> Self {
92 let ops = SqliteVectorStore::new(pool.clone());
93 Self {
94 ops: Box::new(ops),
95 collection: COLLECTION_NAME.into(),
96 pool,
97 }
98 }
99
100 #[must_use]
101 pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
102 Self {
103 ops: store,
104 collection: COLLECTION_NAME.into(),
105 pool,
106 }
107 }
108
109 pub async fn health_check(&self) -> bool {
110 self.ops.health_check().await.unwrap_or(false)
111 }
112
113 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
121 self.ops
122 .ensure_collection(&self.collection, vector_size)
123 .await?;
124 Ok(())
125 }
126
127 pub async fn store(
135 &self,
136 message_id: MessageId,
137 conversation_id: ConversationId,
138 role: &str,
139 vector: Vec<f32>,
140 kind: MessageKind,
141 model: &str,
142 ) -> Result<String, MemoryError> {
143 let point_id = uuid::Uuid::new_v4().to_string();
144 let dimensions = i64::try_from(vector.len())?;
145
146 let payload = std::collections::HashMap::from([
147 ("message_id".to_owned(), serde_json::json!(message_id.0)),
148 (
149 "conversation_id".to_owned(),
150 serde_json::json!(conversation_id.0),
151 ),
152 ("role".to_owned(), serde_json::json!(role)),
153 (
154 "is_summary".to_owned(),
155 serde_json::json!(kind.is_summary()),
156 ),
157 ]);
158
159 let point = VectorPoint {
160 id: point_id.clone(),
161 vector,
162 payload,
163 };
164
165 self.ops.upsert(&self.collection, vec![point]).await?;
166
167 sqlx::query(
168 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
169 VALUES (?, ?, ?, ?) \
170 ON CONFLICT(message_id, model) DO UPDATE SET \
171 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
172 )
173 .bind(message_id)
174 .bind(&point_id)
175 .bind(dimensions)
176 .bind(model)
177 .execute(&self.pool)
178 .await?;
179
180 Ok(point_id)
181 }
182
183 pub async fn search(
189 &self,
190 query_vector: &[f32],
191 limit: usize,
192 filter: Option<SearchFilter>,
193 ) -> Result<Vec<SearchResult>, MemoryError> {
194 let limit_u64 = u64::try_from(limit)?;
195
196 let vector_filter = filter.as_ref().and_then(|f| {
197 let mut must = Vec::new();
198 if let Some(cid) = f.conversation_id {
199 must.push(FieldCondition {
200 field: "conversation_id".into(),
201 value: FieldValue::Integer(cid.0),
202 });
203 }
204 if let Some(ref role) = f.role {
205 must.push(FieldCondition {
206 field: "role".into(),
207 value: FieldValue::Text(role.clone()),
208 });
209 }
210 if must.is_empty() {
211 None
212 } else {
213 Some(VectorFilter {
214 must,
215 must_not: vec![],
216 })
217 }
218 });
219
220 let results = self
221 .ops
222 .search(
223 &self.collection,
224 query_vector.to_vec(),
225 limit_u64,
226 vector_filter,
227 )
228 .await?;
229
230 let search_results = results
231 .into_iter()
232 .filter_map(|point| {
233 let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
234 let conversation_id =
235 ConversationId(point.payload.get("conversation_id")?.as_i64()?);
236 Some(SearchResult {
237 message_id,
238 conversation_id,
239 score: point.score,
240 })
241 })
242 .collect();
243
244 Ok(search_results)
245 }
246
247 pub async fn ensure_named_collection(
253 &self,
254 name: &str,
255 vector_size: u64,
256 ) -> Result<(), MemoryError> {
257 self.ops.ensure_collection(name, vector_size).await?;
258 Ok(())
259 }
260
261 pub async fn store_to_collection(
269 &self,
270 collection: &str,
271 payload: serde_json::Value,
272 vector: Vec<f32>,
273 ) -> Result<String, MemoryError> {
274 let point_id = uuid::Uuid::new_v4().to_string();
275 let payload_map: std::collections::HashMap<String, serde_json::Value> =
276 serde_json::from_value(payload)?;
277 let point = VectorPoint {
278 id: point_id.clone(),
279 vector,
280 payload: payload_map,
281 };
282 self.ops.upsert(collection, vec![point]).await?;
283 Ok(point_id)
284 }
285
286 pub async fn search_collection(
292 &self,
293 collection: &str,
294 query_vector: &[f32],
295 limit: usize,
296 filter: Option<VectorFilter>,
297 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
298 let limit_u64 = u64::try_from(limit)?;
299 let results = self
300 .ops
301 .search(collection, query_vector.to_vec(), limit_u64, filter)
302 .await?;
303 Ok(results)
304 }
305
306 pub async fn get_vectors(
314 &self,
315 ids: &[MessageId],
316 ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
317 if ids.is_empty() {
318 return Ok(std::collections::HashMap::new());
319 }
320
321 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
322 let query = format!(
323 "SELECT em.message_id, vp.vector \
324 FROM embeddings_metadata em \
325 JOIN vector_points vp ON vp.id = em.qdrant_point_id \
326 WHERE em.message_id IN ({placeholders})"
327 );
328 let mut q = sqlx::query_as::<_, (MessageId, Vec<u8>)>(&query);
329 for &id in ids {
330 q = q.bind(id);
331 }
332
333 let rows = q.fetch_all(&self.pool).await.unwrap_or_default();
334
335 let map = rows
336 .into_iter()
337 .filter_map(|(msg_id, blob)| {
338 if blob.len() % 4 != 0 {
339 return None;
340 }
341 let vec: Vec<f32> = blob
342 .chunks_exact(4)
343 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
344 .collect();
345 Some((msg_id, vec))
346 })
347 .collect();
348
349 Ok(map)
350 }
351
352 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
358 let row: (i64,) =
359 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
360 .bind(message_id)
361 .fetch_one(&self.pool)
362 .await?;
363
364 Ok(row.0 > 0)
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    /// Opens a fresh in-memory SQLite store and returns it with a clone of its pool.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// Builds an `EmbeddingStore` over an in-memory vector backend whose
    /// default collection holds 4-dimensional vectors.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Inserts an `embeddings_metadata` row for `msg_id` directly with raw SQL,
    /// bypassing `EmbeddingStore::store`. Each call uses a fresh point id and
    /// the `qwen3-embedding` model.
    async fn insert_metadata(pool: &SqlitePool, msg_id: MessageId) -> Result<(), sqlx::Error> {
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(msg_id)
        .bind(uuid::Uuid::new_v4().to_string())
        .bind(768_i64)
        .bind("qwen3-embedding")
        .execute(pool)
        .await
        .map(|_| ())
    }

    /// Counts `embeddings_metadata` rows for `msg_id` with raw SQL.
    async fn count_metadata(pool: &SqlitePool, msg_id: MessageId) -> i64 {
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        count
    }

    #[test]
    fn message_kind_is_summary_only_for_summary() {
        assert!(MessageKind::Summary.is_summary());
        assert!(!MessageKind::Regular.is_summary());
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        // Exercise the real method rather than re-running its SQL by hand.
        let (store, _sqlite) = setup_with_store().await;
        assert!(!store.has_embedding(MessageId(999)).await.unwrap());
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        insert_metadata(&pool, msg_id).await.unwrap();

        assert_eq!(count_metadata(&pool, msg_id).await, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_search_with_empty_filter_matches_unfiltered() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![1.0, 0.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        // A filter with every field unset must behave like no filter at all.
        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                5,
                Some(SearchFilter {
                    conversation_id: None,
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();

        store
            .store(
                msg_id,
                cid,
                "user",
                vec![0.0, 1.0, 0.0, 0.0],
                MessageKind::Regular,
                "test-model",
            )
            .await
            .unwrap();

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let cid1 = sqlite.create_conversation().await.unwrap();
        let cid2 = sqlite.create_conversation().await.unwrap();
        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();

        for (msg, cid) in [(msg1, cid1), (msg2, cid2)] {
            store
                .store(
                    msg,
                    cid,
                    "user",
                    vec![1.0, 0.0, 0.0, 0.0],
                    MessageKind::Regular,
                    "m",
                )
                .await
                .unwrap();
        }

        let results = store
            .search(
                &[1.0, 0.0, 0.0, 0.0],
                10,
                Some(SearchFilter {
                    conversation_id: Some(cid1),
                    role: None,
                }),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();

        insert_metadata(&pool, msg_id).await.unwrap();

        // A second row for the same (message_id, model) must be rejected,
        // even though it carries a different point id.
        assert!(insert_metadata(&pool, msg_id).await.is_err());
    }
}
575}