1pub use qdrant_client::qdrant::Filter;
2use sqlx::SqlitePool;
3
4use crate::error::MemoryError;
5use crate::qdrant_ops::QdrantOps;
6use crate::types::{ConversationId, MessageId};
7use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
8
/// Classifies a message whose embedding is being stored.
///
/// The kind is persisted in the vector payload as the boolean `is_summary` flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageKind {
    /// Any message that is not a summary.
    Regular,
    /// A message flagged as a summary.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    ///
    /// Declared `const` so the check can also be used in constant contexts.
    #[must_use]
    pub const fn is_summary(self) -> bool {
        matches!(self, Self::Summary)
    }
}
22
23const COLLECTION_NAME: &str = "zeph_conversations";
24
25pub async fn ensure_qdrant_collection(
33 ops: &QdrantOps,
34 collection: &str,
35 vector_size: u64,
36) -> Result<(), Box<qdrant_client::QdrantError>> {
37 ops.ensure_collection(collection, vector_size).await
38}
39
40pub struct EmbeddingStore {
41 ops: Box<dyn VectorStore>,
42 collection: String,
43 pool: SqlitePool,
44}
45
46impl std::fmt::Debug for EmbeddingStore {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("EmbeddingStore")
49 .field("collection", &self.collection)
50 .finish_non_exhaustive()
51 }
52}
53
54#[derive(Debug)]
55pub struct SearchFilter {
56 pub conversation_id: Option<ConversationId>,
57 pub role: Option<String>,
58}
59
60#[derive(Debug)]
61pub struct SearchResult {
62 pub message_id: MessageId,
63 pub conversation_id: ConversationId,
64 pub score: f32,
65}
66
67impl EmbeddingStore {
68 pub fn new(url: &str, pool: SqlitePool) -> Result<Self, MemoryError> {
77 let ops = QdrantOps::new(url).map_err(MemoryError::Qdrant)?;
78
79 Ok(Self {
80 ops: Box::new(ops),
81 collection: COLLECTION_NAME.into(),
82 pool,
83 })
84 }
85
86 #[must_use]
87 pub fn with_store(store: Box<dyn VectorStore>, pool: SqlitePool) -> Self {
88 Self {
89 ops: store,
90 collection: COLLECTION_NAME.into(),
91 pool,
92 }
93 }
94
95 pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
103 self.ops
104 .ensure_collection(&self.collection, vector_size)
105 .await?;
106 Ok(())
107 }
108
109 pub async fn store(
117 &self,
118 message_id: MessageId,
119 conversation_id: ConversationId,
120 role: &str,
121 vector: Vec<f32>,
122 kind: MessageKind,
123 model: &str,
124 ) -> Result<String, MemoryError> {
125 let point_id = uuid::Uuid::new_v4().to_string();
126 let dimensions = i64::try_from(vector.len())?;
127
128 let payload = std::collections::HashMap::from([
129 ("message_id".to_owned(), serde_json::json!(message_id.0)),
130 (
131 "conversation_id".to_owned(),
132 serde_json::json!(conversation_id.0),
133 ),
134 ("role".to_owned(), serde_json::json!(role)),
135 (
136 "is_summary".to_owned(),
137 serde_json::json!(kind.is_summary()),
138 ),
139 ]);
140
141 let point = VectorPoint {
142 id: point_id.clone(),
143 vector,
144 payload,
145 };
146
147 self.ops.upsert(&self.collection, vec![point]).await?;
148
149 sqlx::query(
150 "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
151 VALUES (?, ?, ?, ?) \
152 ON CONFLICT(message_id, model) DO UPDATE SET \
153 qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions",
154 )
155 .bind(message_id)
156 .bind(&point_id)
157 .bind(dimensions)
158 .bind(model)
159 .execute(&self.pool)
160 .await?;
161
162 Ok(point_id)
163 }
164
165 pub async fn search(
171 &self,
172 query_vector: &[f32],
173 limit: usize,
174 filter: Option<SearchFilter>,
175 ) -> Result<Vec<SearchResult>, MemoryError> {
176 let limit_u64 = u64::try_from(limit)?;
177
178 let vector_filter = filter.as_ref().and_then(|f| {
179 let mut must = Vec::new();
180 if let Some(cid) = f.conversation_id {
181 must.push(FieldCondition {
182 field: "conversation_id".into(),
183 value: FieldValue::Integer(cid.0),
184 });
185 }
186 if let Some(ref role) = f.role {
187 must.push(FieldCondition {
188 field: "role".into(),
189 value: FieldValue::Text(role.clone()),
190 });
191 }
192 if must.is_empty() {
193 None
194 } else {
195 Some(VectorFilter {
196 must,
197 must_not: vec![],
198 })
199 }
200 });
201
202 let results = self
203 .ops
204 .search(
205 &self.collection,
206 query_vector.to_vec(),
207 limit_u64,
208 vector_filter,
209 )
210 .await?;
211
212 let search_results = results
213 .into_iter()
214 .filter_map(|point| {
215 let message_id = MessageId(point.payload.get("message_id")?.as_i64()?);
216 let conversation_id =
217 ConversationId(point.payload.get("conversation_id")?.as_i64()?);
218 Some(SearchResult {
219 message_id,
220 conversation_id,
221 score: point.score,
222 })
223 })
224 .collect();
225
226 Ok(search_results)
227 }
228
229 pub async fn ensure_named_collection(
235 &self,
236 name: &str,
237 vector_size: u64,
238 ) -> Result<(), MemoryError> {
239 self.ops.ensure_collection(name, vector_size).await?;
240 Ok(())
241 }
242
243 pub async fn store_to_collection(
251 &self,
252 collection: &str,
253 payload: serde_json::Value,
254 vector: Vec<f32>,
255 ) -> Result<String, MemoryError> {
256 let point_id = uuid::Uuid::new_v4().to_string();
257 let payload_map: std::collections::HashMap<String, serde_json::Value> =
258 serde_json::from_value(payload)?;
259 let point = VectorPoint {
260 id: point_id.clone(),
261 vector,
262 payload: payload_map,
263 };
264 self.ops.upsert(collection, vec![point]).await?;
265 Ok(point_id)
266 }
267
268 pub async fn search_collection(
274 &self,
275 collection: &str,
276 query_vector: &[f32],
277 limit: usize,
278 filter: Option<VectorFilter>,
279 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
280 let limit_u64 = u64::try_from(limit)?;
281 let results = self
282 .ops
283 .search(collection, query_vector.to_vec(), limit_u64, filter)
284 .await?;
285 Ok(results)
286 }
287
288 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
294 let row: (i64,) =
295 sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
296 .bind(message_id)
297 .fetch_one(&self.pool)
298 .await?;
299
300 Ok(row.0 > 0)
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    /// Counts metadata rows for a single message id.
    const COUNT_BY_MESSAGE: &str = "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?";

    /// Plain insert without an `ON CONFLICT` clause, so constraint violations surface as errors.
    const INSERT_METADATA: &str =
        "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions, model) \
         VALUES (?, ?, ?, ?)";

    /// Query/stored vector shared by the search tests (matches the 4-dim test collection).
    const UNIT_X: [f32; 4] = [1.0, 0.0, 0.0, 0.0];

    /// In-memory SQLite store plus a clone of its pool for raw queries.
    async fn setup() -> (SqliteStore, SqlitePool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (store, pool)
    }

    /// `EmbeddingStore` over an in-memory vector backend with a 4-dimensional collection.
    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let backend = Box::new(InMemoryVectorStore::new());
        let embedding_store = EmbeddingStore::with_store(backend, sqlite.pool().clone());
        embedding_store.ensure_collection(4).await.unwrap();
        (embedding_store, sqlite)
    }

    /// Creates a conversation containing one `user` message with the given text.
    async fn seed_message(sqlite: &SqliteStore, content: &str) -> (ConversationId, MessageId) {
        let cid = sqlite.create_conversation().await.unwrap();
        let msg_id = sqlite.save_message(cid, "user", content).await.unwrap();
        (cid, msg_id)
    }

    /// Inserts a metadata row for `message_id` with a fresh point id, 768 dimensions and the
    /// `qwen3-embedding` model.
    async fn insert_metadata(pool: &SqlitePool, message_id: MessageId) -> Result<(), sqlx::Error> {
        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(INSERT_METADATA)
            .bind(message_id)
            .bind(&point_id)
            .bind(768_i64)
            .bind("qwen3-embedding")
            .execute(pool)
            .await
            .map(|_| ())
    }

    /// Stores `vector` for a regular `user` message and panics on failure.
    async fn embed(
        store: &EmbeddingStore,
        msg_id: MessageId,
        cid: ConversationId,
        vector: Vec<f32>,
        model: &str,
    ) {
        store
            .store(msg_id, cid, "user", vector, MessageKind::Regular, model)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn has_embedding_returns_false_when_none() {
        let (_store, pool) = setup().await;

        let (count,): (i64,) = sqlx::query_as(COUNT_BY_MESSAGE)
            .bind(999_i64)
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn insert_and_query_embeddings_metadata() {
        let (sqlite, pool) = setup().await;
        let (_cid, msg_id) = seed_message(&sqlite, "test").await;

        insert_metadata(&pool, msg_id).await.unwrap();

        let (count,): (i64,) = sqlx::query_as(COUNT_BY_MESSAGE)
            .bind(msg_id)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn embedding_store_search_empty_returns_empty() {
        let (store, _sqlite) = setup_with_store().await;
        let results = store.search(&UNIT_X, 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn embedding_store_store_and_search() {
        let (store, sqlite) = setup_with_store().await;
        let (cid, msg_id) = seed_message(&sqlite, "test message").await;

        embed(&store, msg_id, cid, UNIT_X.to_vec(), "test-model").await;

        let results = store.search(&UNIT_X, 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].message_id, msg_id);
        assert_eq!(results[0].conversation_id, cid);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_false_for_unknown() {
        let (store, sqlite) = setup_with_store().await;
        let (_cid, msg_id) = seed_message(&sqlite, "test").await;
        assert!(!store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_has_embedding_true_after_store() {
        let (store, sqlite) = setup_with_store().await;
        let (cid, msg_id) = seed_message(&sqlite, "hello").await;

        embed(&store, msg_id, cid, vec![0.0, 1.0, 0.0, 0.0], "test-model").await;

        assert!(store.has_embedding(msg_id).await.unwrap());
    }

    #[tokio::test]
    async fn embedding_store_search_with_conversation_filter() {
        let (store, sqlite) = setup_with_store().await;
        let (cid1, msg1) = seed_message(&sqlite, "msg1").await;
        let (cid2, msg2) = seed_message(&sqlite, "msg2").await;

        // Identical vectors in two conversations: only the filter can tell them apart.
        embed(&store, msg1, cid1, UNIT_X.to_vec(), "m").await;
        embed(&store, msg2, cid2, UNIT_X.to_vec(), "m").await;

        let only_first = SearchFilter {
            conversation_id: Some(cid1),
            role: None,
        };
        let results = store.search(&UNIT_X, 10, Some(only_first)).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].conversation_id, cid1);
    }

    #[tokio::test]
    async fn unique_constraint_on_message_and_model() {
        let (sqlite, pool) = setup().await;
        let (_cid, msg_id) = seed_message(&sqlite, "test").await;

        insert_metadata(&pool, msg_id).await.unwrap();

        // Same message and model again, different point id: the plain INSERT must be rejected.
        let result = insert_metadata(&pool, msg_id).await;

        assert!(result.is_err());
    }
}