Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Qdrant-backed embedding store for message vector search.
5//!
6//! [`EmbeddingStore`] owns a [`VectorStore`] implementation (Qdrant in production,
7//! [`crate::db_vector_store::DbVectorStore`] in tests) and exposes typed `embed` /
8//! `search` / `delete` operations used by [`crate::semantic::SemanticMemory`].
9//!
10//! Message vectors are stored in the `zeph_conversations` Qdrant collection with a
11//! payload that includes `message_id`, `conversation_id`, `role`, and `category`.
12
13pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Classifies a stored embedding as belonging to an ordinary message or to a summary.
///
/// The store methods in this module write the classification into the point payload as
/// the boolean `is_summary` field (see [`MessageKind::is_summary`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A compression summary generated by the summarization subsystem.
    Summary,
}

impl MessageKind {
    /// Returns `true` for [`MessageKind::Summary`] and `false` for [`MessageKind::Regular`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
42
/// Name of the collection that every `EmbeddingStore` constructor in this module targets.
const COLLECTION_NAME: &str = "zeph_conversations";
44
45/// Ensure a Qdrant collection exists with cosine distance vectors.
46///
47/// Idempotent: no-op if the collection already exists.
48///
49/// # Errors
50///
51/// Returns an error if Qdrant cannot be reached or collection creation fails.
52pub async fn ensure_qdrant_collection(
53    ops: &QdrantOps,
54    collection: &str,
55    vector_size: u64,
56) -> Result<(), Box<qdrant_client::QdrantError>> {
57    ops.ensure_collection(collection, vector_size).await
58}
59
60/// Typed wrapper over a [`VectorStore`] backend for conversation message embeddings.
61///
62/// Constructed via [`EmbeddingStore::new`] (Qdrant URL + optional API key) or
63/// [`EmbeddingStore::with_store`] (custom backend for testing).
64pub struct EmbeddingStore {
65    ops: Box<dyn VectorStore>,
66    collection: String,
67    pool: DbPool,
68}
69
70impl std::fmt::Debug for EmbeddingStore {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("EmbeddingStore")
73            .field("collection", &self.collection)
74            .finish_non_exhaustive()
75    }
76}
77
78/// Optional filters applied to a vector similarity search.
79#[derive(Debug)]
80pub struct SearchFilter {
81    /// Restrict results to a single conversation. `None` searches across all conversations.
82    pub conversation_id: Option<ConversationId>,
83    /// Restrict by message role (`"user"` / `"assistant"`). `None` returns all roles.
84    pub role: Option<String>,
85    /// Restrict by category payload field (category-aware memory, #2428).
86    /// When `Some`, Qdrant search is restricted to vectors with a matching `category` payload.
87    pub category: Option<String>,
88}
89
90/// A single result returned by [`EmbeddingStore::search`].
91#[derive(Debug)]
92pub struct SearchResult {
93    /// Database row ID of the matching message.
94    pub message_id: MessageId,
95    /// Conversation the message belongs to.
96    pub conversation_id: ConversationId,
97    /// Cosine similarity score in `[0, 1]`.
98    pub score: f32,
99}
100
101impl EmbeddingStore {
102    /// Create a new `EmbeddingStore` connected to the given Qdrant URL with optional API key.
103    ///
104    /// `api_key` is forwarded to [`QdrantOps::new`]. The `pool` is used for `SQLite` metadata
105    /// operations on the `embeddings_metadata` table (which must already exist via sqlx
106    /// migrations).
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if the Qdrant client cannot be created.
111    pub fn new(url: &str, api_key: Option<&str>, pool: DbPool) -> Result<Self, MemoryError> {
112        let ops = QdrantOps::new(url, api_key).map_err(MemoryError::Qdrant)?;
113
114        Ok(Self {
115            ops: Box::new(ops),
116            collection: COLLECTION_NAME.into(),
117            pool,
118        })
119    }
120
121    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
122    ///
123    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
124    #[must_use]
125    pub fn new_sqlite(pool: DbPool) -> Self {
126        let ops = DbVectorStore::new(pool.clone());
127        Self {
128            ops: Box::new(ops),
129            collection: COLLECTION_NAME.into(),
130            pool,
131        }
132    }
133
134    #[must_use]
135    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
136        Self {
137            ops: store,
138            collection: COLLECTION_NAME.into(),
139            pool,
140        }
141    }
142
143    pub async fn health_check(&self) -> bool {
144        self.ops.health_check().await.unwrap_or(false)
145    }
146
147    /// Ensure the collection exists in Qdrant with the given vector size.
148    ///
149    /// Idempotent: no-op if the collection already exists.
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if Qdrant cannot be reached or collection creation fails.
154    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
155        self.ops
156            .ensure_collection(&self.collection, vector_size)
157            .await?;
158        // Create keyword indexes for the fields used in filtered recall so Qdrant can satisfy
159        // filter conditions in O(log n) instead of scanning all payload documents.
160        self.ops
161            .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
162            .await?;
163        Ok(())
164    }
165
166    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
167    ///
168    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
169    /// alongside the standard fields. This allows filtering and scoring by tool context
170    /// without corrupting the embedding vector with text prefixes.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
175    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
176    pub async fn store_with_tool_context(
177        &self,
178        message_id: MessageId,
179        conversation_id: ConversationId,
180        role: &str,
181        vector: Vec<f32>,
182        kind: MessageKind,
183        model: &str,
184        chunk_index: u32,
185        tool_name: &str,
186        exit_code: Option<i32>,
187        timestamp: Option<&str>,
188    ) -> Result<String, MemoryError> {
189        let point_id = uuid::Uuid::new_v4().to_string();
190        let dimensions = i64::try_from(vector.len())?;
191
192        let mut payload = std::collections::HashMap::from([
193            ("message_id".to_owned(), serde_json::json!(message_id.0)),
194            (
195                "conversation_id".to_owned(),
196                serde_json::json!(conversation_id.0),
197            ),
198            ("role".to_owned(), serde_json::json!(role)),
199            (
200                "is_summary".to_owned(),
201                serde_json::json!(kind.is_summary()),
202            ),
203            ("tool_name".to_owned(), serde_json::json!(tool_name)),
204        ]);
205        if let Some(code) = exit_code {
206            payload.insert("exit_code".to_owned(), serde_json::json!(code));
207        }
208        if let Some(ts) = timestamp {
209            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
210        }
211
212        let point = VectorPoint {
213            id: point_id.clone(),
214            vector,
215            payload,
216        };
217
218        self.ops.upsert(&self.collection, vec![point]).await?;
219
220        let chunk_index_i64 = i64::from(chunk_index);
221        zeph_db::query(sql!(
222            "INSERT INTO embeddings_metadata \
223             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
224             VALUES (?, ?, ?, ?, ?) \
225             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
226             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
227        ))
228        .bind(message_id)
229        .bind(chunk_index_i64)
230        .bind(&point_id)
231        .bind(dimensions)
232        .bind(model)
233        .execute(&self.pool)
234        .await?;
235
236        Ok(point_id)
237    }
238
239    /// Store a vector in Qdrant and persist metadata to `SQLite`.
240    ///
241    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
242    /// when a long message is split into multiple embeddings.
243    ///
244    /// Returns the UUID of the newly created Qdrant point.
245    ///
246    /// # Errors
247    ///
248    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
249    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
250    pub async fn store(
251        &self,
252        message_id: MessageId,
253        conversation_id: ConversationId,
254        role: &str,
255        vector: Vec<f32>,
256        kind: MessageKind,
257        model: &str,
258        chunk_index: u32,
259    ) -> Result<String, MemoryError> {
260        let point_id = uuid::Uuid::new_v4().to_string();
261        let dimensions = i64::try_from(vector.len())?;
262
263        let payload = std::collections::HashMap::from([
264            ("message_id".to_owned(), serde_json::json!(message_id.0)),
265            (
266                "conversation_id".to_owned(),
267                serde_json::json!(conversation_id.0),
268            ),
269            ("role".to_owned(), serde_json::json!(role)),
270            (
271                "is_summary".to_owned(),
272                serde_json::json!(kind.is_summary()),
273            ),
274        ]);
275
276        let point = VectorPoint {
277            id: point_id.clone(),
278            vector,
279            payload,
280        };
281
282        self.ops.upsert(&self.collection, vec![point]).await?;
283
284        let chunk_index_i64 = i64::from(chunk_index);
285        zeph_db::query(sql!(
286            "INSERT INTO embeddings_metadata \
287             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
288             VALUES (?, ?, ?, ?, ?) \
289             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
290             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
291        ))
292        .bind(message_id)
293        .bind(chunk_index_i64)
294        .bind(&point_id)
295        .bind(dimensions)
296        .bind(model)
297        .execute(&self.pool)
298        .await?;
299
300        Ok(point_id)
301    }
302
303    /// Store a vector with an optional category tag in the Qdrant payload.
304    ///
305    /// Identical to [`Self::store`] but adds a `category` field to the payload when provided.
306    /// Used by category-aware memory (#2428) to enable category-filtered recall.
307    ///
308    /// Note: when `category` is `None` no `category` field is written to the Qdrant payload.
309    /// Memories stored before category-aware recall was enabled therefore won't match a
310    /// category filter — this is intentional (no silent false-positives), but a backfill
311    /// pass is needed if retrospective categorization is desired.
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
316    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
317    pub async fn store_with_category(
318        &self,
319        message_id: MessageId,
320        conversation_id: ConversationId,
321        role: &str,
322        vector: Vec<f32>,
323        kind: MessageKind,
324        model: &str,
325        chunk_index: u32,
326        category: Option<&str>,
327    ) -> Result<String, MemoryError> {
328        let point_id = uuid::Uuid::new_v4().to_string();
329        let dimensions = i64::try_from(vector.len())?;
330
331        let mut payload = std::collections::HashMap::from([
332            ("message_id".to_owned(), serde_json::json!(message_id.0)),
333            (
334                "conversation_id".to_owned(),
335                serde_json::json!(conversation_id.0),
336            ),
337            ("role".to_owned(), serde_json::json!(role)),
338            (
339                "is_summary".to_owned(),
340                serde_json::json!(kind.is_summary()),
341            ),
342        ]);
343        if let Some(cat) = category {
344            payload.insert("category".to_owned(), serde_json::json!(cat));
345        }
346
347        let point = VectorPoint {
348            id: point_id.clone(),
349            vector,
350            payload,
351        };
352
353        self.ops.upsert(&self.collection, vec![point]).await?;
354
355        let chunk_index_i64 = i64::from(chunk_index);
356        zeph_db::query(sql!(
357            "INSERT INTO embeddings_metadata \
358             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
359             VALUES (?, ?, ?, ?, ?) \
360             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
361             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
362        ))
363        .bind(message_id)
364        .bind(chunk_index_i64)
365        .bind(&point_id)
366        .bind(dimensions)
367        .bind(model)
368        .execute(&self.pool)
369        .await?;
370
371        Ok(point_id)
372    }
373
374    /// Search for similar vectors in Qdrant, returning up to `limit` results.
375    ///
376    /// # Errors
377    ///
378    /// Returns an error if the Qdrant search fails.
379    pub async fn search(
380        &self,
381        query_vector: &[f32],
382        limit: usize,
383        filter: Option<SearchFilter>,
384    ) -> Result<Vec<SearchResult>, MemoryError> {
385        let limit_u64 = u64::try_from(limit)?;
386
387        let vector_filter = filter.as_ref().and_then(|f| {
388            let mut must = Vec::new();
389            if let Some(cid) = f.conversation_id {
390                must.push(FieldCondition {
391                    field: "conversation_id".into(),
392                    value: FieldValue::Integer(cid.0),
393                });
394            }
395            if let Some(ref role) = f.role {
396                must.push(FieldCondition {
397                    field: "role".into(),
398                    value: FieldValue::Text(role.clone()),
399                });
400            }
401            if let Some(ref category) = f.category {
402                must.push(FieldCondition {
403                    field: "category".into(),
404                    value: FieldValue::Text(category.clone()),
405                });
406            }
407            if must.is_empty() {
408                None
409            } else {
410                Some(VectorFilter {
411                    must,
412                    must_not: vec![],
413                })
414            }
415        });
416
417        let results = self
418            .ops
419            .search(
420                &self.collection,
421                query_vector.to_vec(),
422                limit_u64,
423                vector_filter,
424            )
425            .await?;
426
427        // Deduplicate by message_id, keeping the chunk with the highest score.
428        // A single message may produce multiple Qdrant points (one per chunk).
429        let mut best: std::collections::HashMap<MessageId, SearchResult> =
430            std::collections::HashMap::new();
431        for point in results {
432            let Some(message_id) = point
433                .payload
434                .get("message_id")
435                .and_then(serde_json::Value::as_i64)
436            else {
437                continue;
438            };
439            let Some(conversation_id) = point
440                .payload
441                .get("conversation_id")
442                .and_then(serde_json::Value::as_i64)
443            else {
444                continue;
445            };
446            let message_id = MessageId(message_id);
447            let entry = best.entry(message_id).or_insert(SearchResult {
448                message_id,
449                conversation_id: ConversationId(conversation_id),
450                score: f32::NEG_INFINITY,
451            });
452            if point.score > entry.score {
453                entry.score = point.score;
454            }
455        }
456
457        let mut search_results: Vec<SearchResult> = best.into_values().collect();
458        search_results.sort_by(|a, b| {
459            b.score
460                .partial_cmp(&a.score)
461                .unwrap_or(std::cmp::Ordering::Equal)
462        });
463        search_results.truncate(limit);
464
465        Ok(search_results)
466    }
467
468    /// Check whether a named collection exists in the vector store.
469    ///
470    /// # Errors
471    ///
472    /// Returns an error if the store backend cannot be reached.
473    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
474        self.ops.collection_exists(name).await.map_err(Into::into)
475    }
476
477    /// Ensure a named collection exists in Qdrant with the given vector size.
478    ///
479    /// # Errors
480    ///
481    /// Returns an error if Qdrant cannot be reached or collection creation fails.
482    pub async fn ensure_named_collection(
483        &self,
484        name: &str,
485        vector_size: u64,
486    ) -> Result<(), MemoryError> {
487        self.ops.ensure_collection(name, vector_size).await?;
488        Ok(())
489    }
490
491    /// Store a vector in a named Qdrant collection with arbitrary payload.
492    ///
493    /// Returns the UUID of the newly created point.
494    ///
495    /// # Errors
496    ///
497    /// Returns an error if the Qdrant upsert fails.
498    pub async fn store_to_collection(
499        &self,
500        collection: &str,
501        payload: serde_json::Value,
502        vector: Vec<f32>,
503    ) -> Result<String, MemoryError> {
504        let point_id = uuid::Uuid::new_v4().to_string();
505        let payload_map: std::collections::HashMap<String, serde_json::Value> =
506            serde_json::from_value(payload)?;
507        let point = VectorPoint {
508            id: point_id.clone(),
509            vector,
510            payload: payload_map,
511        };
512        self.ops.upsert(collection, vec![point]).await?;
513        Ok(point_id)
514    }
515
516    /// Upsert a vector into a named collection, reusing an existing point ID.
517    ///
518    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
519    ///
520    /// # Errors
521    ///
522    /// Returns an error if the Qdrant upsert fails.
523    pub async fn upsert_to_collection(
524        &self,
525        collection: &str,
526        point_id: &str,
527        payload: serde_json::Value,
528        vector: Vec<f32>,
529    ) -> Result<(), MemoryError> {
530        let payload_map: std::collections::HashMap<String, serde_json::Value> =
531            serde_json::from_value(payload)?;
532        let point = VectorPoint {
533            id: point_id.to_owned(),
534            vector,
535            payload: payload_map,
536        };
537        self.ops.upsert(collection, vec![point]).await?;
538        Ok(())
539    }
540
541    /// Search a named Qdrant collection, returning scored points with payloads.
542    ///
543    /// # Errors
544    ///
545    /// Returns an error if the Qdrant search fails.
546    pub async fn search_collection(
547        &self,
548        collection: &str,
549        query_vector: &[f32],
550        limit: usize,
551        filter: Option<VectorFilter>,
552    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
553        let limit_u64 = u64::try_from(limit)?;
554        let results = self
555            .ops
556            .search(collection, query_vector.to_vec(), limit_u64, filter)
557            .await?;
558        Ok(results)
559    }
560
561    /// Enumerate `(point_id, entity_id)` pairs for all points in `collection` that carry
562    /// an `entity_id_str` payload field.
563    ///
564    /// `entity_id_str` is a string mirror of the i64 `entity_id` written alongside the numeric
565    /// field at embedding time. The scroll API only surfaces string-typed payload values, so a
566    /// parallel string field is necessary for enumeration. Points missing `entity_id_str`
567    /// (written before this field was added) are silently skipped — they will gain the field on
568    /// the next `merge_entity` or `store_entity_embedding` call.
569    ///
570    /// # Errors
571    ///
572    /// Returns an error if the underlying scroll operation fails.
573    pub async fn scroll_all_entity_ids(
574        &self,
575        collection: &str,
576    ) -> Result<Vec<(String, i64)>, MemoryError> {
577        let rows = self
578            .ops
579            .scroll_all_with_point_ids(collection, "entity_id_str")
580            .await?;
581        let mut out = Vec::with_capacity(rows.len());
582        for (point_id, fields) in rows {
583            let Some(s) = fields.get("entity_id_str") else {
584                continue;
585            };
586            if let Ok(id) = s.parse::<i64>() {
587                out.push((point_id, id));
588            } else {
589                tracing::debug!(point_id, value = %s, "entity_id_str unparseable, skipping");
590            }
591        }
592        Ok(out)
593    }
594
595    /// Delete a set of points from a named collection by their Qdrant point IDs.
596    ///
597    /// This is a thin wrapper over [`VectorStore::delete_by_ids`] for use by
598    /// the stale-embedding cleanup path in `community.rs`.
599    ///
600    /// # Errors
601    ///
602    /// Returns an error if the underlying delete operation fails.
603    pub async fn delete_from_collection(
604        &self,
605        collection: &str,
606        ids: Vec<String>,
607    ) -> Result<(), MemoryError> {
608        if ids.is_empty() {
609            return Ok(());
610        }
611        self.ops.delete_by_ids(collection, ids).await?;
612        Ok(())
613    }
614
615    /// Retrieve raw vectors for the given Qdrant point IDs from `collection`.
616    ///
617    /// Returns a map of `point_id → embedding`. Missing ids are silently dropped.
618    /// Returns an empty map when the backend does not support vector retrieval
619    /// (e.g. `DbVectorStore` / `InMemoryVectorStore` without an override).
620    ///
621    /// # Errors
622    ///
623    /// Returns an error if the underlying store returns a non-`Unsupported` error.
624    pub async fn get_vectors_from_collection(
625        &self,
626        collection: &str,
627        point_ids: &[String],
628    ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
629        if point_ids.is_empty() {
630            return Ok(std::collections::HashMap::new());
631        }
632        match self.ops.get_points(collection, point_ids.to_vec()).await {
633            Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
634            Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
635            Err(e) => Err(MemoryError::VectorStore(e)),
636        }
637    }
638
639    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
640    ///
641    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
642    ///
643    /// # Errors
644    ///
645    /// Returns an error if the `SQLite` query fails.
646    pub async fn get_vectors(
647        &self,
648        ids: &[MessageId],
649    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
650        if ids.is_empty() {
651            return Ok(std::collections::HashMap::new());
652        }
653
654        let placeholders = zeph_db::placeholder_list(1, ids.len());
655        let query = format!(
656            "SELECT em.message_id, vp.vector \
657             FROM embeddings_metadata em \
658             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
659             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
660        );
661        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
662        for &id in ids {
663            q = q.bind(id);
664        }
665
666        let rows = q.fetch_all(&self.pool).await?;
667
668        let map = rows
669            .into_iter()
670            .filter_map(|(msg_id, blob)| {
671                if blob.len() % 4 != 0 {
672                    return None;
673                }
674                let vec: Vec<f32> = blob
675                    .chunks_exact(4)
676                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
677                    .collect();
678                Some((msg_id, vec))
679            })
680            .collect();
681
682        Ok(map)
683    }
684
685    /// Fetch embeddings for the given message IDs from the configured vector store.
686    ///
687    /// Resolves `message_id → qdrant_point_id` via `embeddings_metadata` (filtering to
688    /// `chunk_index = 0` so each message yields at most one vector), then retrieves the
689    /// vectors from the underlying [`VectorStore`].
690    ///
691    /// Returns a map from [`MessageId`] to embedding vector. Messages without an
692    /// `embeddings_metadata` row, or whose vector cannot be retrieved, are silently dropped.
693    /// When the backend returns [`crate::VectorStoreError::Unsupported`], an empty map is
694    /// returned without error (matches [`Self::get_vectors_from_collection`] semantics).
695    ///
696    /// # Errors
697    ///
698    /// Returns an error if the `SQLite` metadata query or vector store retrieval fails.
699    pub async fn get_vectors_for_messages(
700        &self,
701        ids: &[MessageId],
702    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
703        if ids.is_empty() {
704            return Ok(std::collections::HashMap::new());
705        }
706
707        let placeholders = zeph_db::placeholder_list(1, ids.len());
708        let query = format!(
709            "SELECT message_id, qdrant_point_id \
710             FROM embeddings_metadata \
711             WHERE message_id IN ({placeholders}) AND chunk_index = 0"
712        );
713        let mut q = zeph_db::query_as::<_, (MessageId, String)>(&query);
714        for &id in ids {
715            q = q.bind(id);
716        }
717        let rows: Vec<(MessageId, String)> = q.fetch_all(&self.pool).await?;
718
719        if rows.is_empty() {
720            return Ok(std::collections::HashMap::new());
721        }
722
723        // Build reverse map: point_id → message_id for result translation.
724        let mut point_to_msg: std::collections::HashMap<String, MessageId> =
725            std::collections::HashMap::with_capacity(rows.len());
726        let point_ids: Vec<String> = rows
727            .into_iter()
728            .map(|(msg_id, point_id)| {
729                point_to_msg.insert(point_id.clone(), msg_id);
730                point_id
731            })
732            .collect();
733
734        let points = match self.ops.get_points(&self.collection, point_ids).await {
735            Ok(pts) => pts,
736            Err(crate::VectorStoreError::Unsupported(_)) => {
737                return Ok(std::collections::HashMap::new());
738            }
739            Err(e) => return Err(MemoryError::VectorStore(e)),
740        };
741
742        let result = points
743            .into_iter()
744            .filter_map(|p| {
745                let msg_id = point_to_msg.get(&p.id).copied()?;
746                Some((msg_id, p.vector))
747            })
748            .collect();
749
750        Ok(result)
751    }
752
753    /// Delete all Qdrant vectors associated with the given message IDs.
754    ///
755    /// Resolves `message_id → qdrant_point_id` via the `embeddings_metadata` table,
756    /// then calls the underlying vector store's `delete_by_ids`. The
757    /// `embeddings_metadata` rows are **not** removed here — the `SQLite` CASCADE on
758    /// `messages` handles that when the rows are hard-deleted later.
759    ///
760    /// Returns the number of Qdrant point IDs targeted for deletion (may be less than
761    /// `ids.len()` when some messages have no embeddings).
762    ///
763    /// # Errors
764    ///
765    /// Returns [`MemoryError`] if the `SQLite` query or the vector store delete fails.
766    pub async fn delete_by_message_ids(&self, ids: &[MessageId]) -> Result<usize, MemoryError> {
767        if ids.is_empty() {
768            return Ok(0);
769        }
770
771        let placeholders = zeph_db::placeholder_list(1, ids.len());
772        let query = format!(
773            "SELECT qdrant_point_id FROM embeddings_metadata WHERE message_id IN ({placeholders})"
774        );
775        let mut q = zeph_db::query_as::<_, (String,)>(&query);
776        for &id in ids {
777            q = q.bind(id);
778        }
779        let rows: Vec<(String,)> = q.fetch_all(&self.pool).await?;
780
781        let point_ids: Vec<String> = rows.into_iter().map(|(id,)| id).collect();
782        let count = point_ids.len();
783
784        if !point_ids.is_empty() {
785            self.ops.delete_by_ids(&self.collection, point_ids).await?;
786        }
787
788        Ok(count)
789    }
790
791    /// Check whether an embedding already exists for the given message ID.
792    ///
793    /// # Errors
794    ///
795    /// Returns an error if the `SQLite` query fails.
796    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
797        let row: (i64,) = zeph_db::query_as(sql!(
798            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
799        ))
800        .bind(message_id)
801        .fetch_one(&self.pool)
802        .await?;
803
804        Ok(row.0 > 0)
805    }
806
807    /// Check whether a Qdrant embedding for `entity_name` is current by comparing the
808    /// Qdrant-side epoch against the epoch stored in `graph_entities`.
809    ///
810    /// Returns `true` if the Qdrant embedding is up-to-date or if the entity no longer
811    /// exists in `SQLite` (embedding should be cleaned up separately).
812    ///
813    /// # Errors
814    ///
815    /// Returns an error if the `SQLite` query fails.
816    pub async fn is_epoch_current(
817        &self,
818        entity_name: &str,
819        qdrant_epoch: u64,
820    ) -> Result<bool, MemoryError> {
821        let row: Option<(i64,)> = zeph_db::query_as(sql!(
822            "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
823        ))
824        .bind(entity_name)
825        .fetch_optional(&self.pool)
826        .await?;
827
828        match row {
829            None => Ok(true), // entity deleted; Qdrant point is orphaned, not stale per epoch
830            Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
831        }
832    }
833}
834
835#[cfg(test)]
836mod tests {
837    use super::*;
838    use crate::in_memory_store::InMemoryVectorStore;
839    use crate::store::SqliteStore;
840
841    async fn setup() -> (SqliteStore, DbPool) {
842        let store = SqliteStore::new(":memory:").await.unwrap();
843        let pool = store.pool().clone();
844        (store, pool)
845    }
846
847    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
848        let sqlite = SqliteStore::new(":memory:").await.unwrap();
849        let pool = sqlite.pool().clone();
850        let mem_store = Box::new(InMemoryVectorStore::new());
851        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
852        // Create collection first
853        embedding_store.ensure_collection(4).await.unwrap();
854        (embedding_store, sqlite)
855    }
856
857    #[tokio::test]
858    async fn has_embedding_returns_false_when_none() {
859        let (_store, pool) = setup().await;
860
861        let row: (i64,) = zeph_db::query_as(sql!(
862            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
863        ))
864        .bind(999_i64)
865        .fetch_one(&pool)
866        .await
867        .unwrap();
868
869        assert_eq!(row.0, 0);
870    }
871
872    #[tokio::test]
873    async fn insert_and_query_embeddings_metadata() {
874        let (sqlite, pool) = setup().await;
875        let cid = sqlite.create_conversation().await.unwrap();
876        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
877
878        let point_id = uuid::Uuid::new_v4().to_string();
879        zeph_db::query(sql!(
880            "INSERT INTO embeddings_metadata \
881             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
882             VALUES (?, ?, ?, ?, ?)"
883        ))
884        .bind(msg_id)
885        .bind(0_i64)
886        .bind(&point_id)
887        .bind(768_i64)
888        .bind("qwen3-embedding")
889        .execute(&pool)
890        .await
891        .unwrap();
892
893        let row: (i64,) = zeph_db::query_as(sql!(
894            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
895        ))
896        .bind(msg_id)
897        .fetch_one(&pool)
898        .await
899        .unwrap();
900        assert_eq!(row.0, 1);
901    }
902
903    #[tokio::test]
904    async fn embedding_store_search_empty_returns_empty() {
905        let (store, _sqlite) = setup_with_store().await;
906        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
907        assert!(results.is_empty());
908    }
909
910    #[tokio::test]
911    async fn embedding_store_store_and_search() {
912        let (store, sqlite) = setup_with_store().await;
913        let cid = sqlite.create_conversation().await.unwrap();
914        let msg_id = sqlite
915            .save_message(cid, "user", "test message")
916            .await
917            .unwrap();
918
919        store
920            .store(
921                msg_id,
922                cid,
923                "user",
924                vec![1.0, 0.0, 0.0, 0.0],
925                MessageKind::Regular,
926                "test-model",
927                0,
928            )
929            .await
930            .unwrap();
931
932        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
933        assert_eq!(results.len(), 1);
934        assert_eq!(results[0].message_id, msg_id);
935        assert_eq!(results[0].conversation_id, cid);
936        assert!((results[0].score - 1.0).abs() < 0.001);
937    }
938
939    #[tokio::test]
940    async fn embedding_store_has_embedding_false_for_unknown() {
941        let (store, sqlite) = setup_with_store().await;
942        let cid = sqlite.create_conversation().await.unwrap();
943        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
944        assert!(!store.has_embedding(msg_id).await.unwrap());
945    }
946
947    #[tokio::test]
948    async fn embedding_store_has_embedding_true_after_store() {
949        let (store, sqlite) = setup_with_store().await;
950        let cid = sqlite.create_conversation().await.unwrap();
951        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
952
953        store
954            .store(
955                msg_id,
956                cid,
957                "user",
958                vec![0.0, 1.0, 0.0, 0.0],
959                MessageKind::Regular,
960                "test-model",
961                0,
962            )
963            .await
964            .unwrap();
965
966        assert!(store.has_embedding(msg_id).await.unwrap());
967    }
968
969    #[tokio::test]
970    async fn embedding_store_search_with_conversation_filter() {
971        let (store, sqlite) = setup_with_store().await;
972        let cid1 = sqlite.create_conversation().await.unwrap();
973        let cid2 = sqlite.create_conversation().await.unwrap();
974        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
975        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
976
977        store
978            .store(
979                msg1,
980                cid1,
981                "user",
982                vec![1.0, 0.0, 0.0, 0.0],
983                MessageKind::Regular,
984                "m",
985                0,
986            )
987            .await
988            .unwrap();
989        store
990            .store(
991                msg2,
992                cid2,
993                "user",
994                vec![1.0, 0.0, 0.0, 0.0],
995                MessageKind::Regular,
996                "m",
997                0,
998            )
999            .await
1000            .unwrap();
1001
1002        let results = store
1003            .search(
1004                &[1.0, 0.0, 0.0, 0.0],
1005                10,
1006                Some(SearchFilter {
1007                    conversation_id: Some(cid1),
1008                    role: None,
1009                    category: None,
1010                }),
1011            )
1012            .await
1013            .unwrap();
1014        assert_eq!(results.len(), 1);
1015        assert_eq!(results[0].conversation_id, cid1);
1016    }
1017
1018    #[tokio::test]
1019    async fn unique_constraint_on_message_chunk_and_model() {
1020        let (sqlite, pool) = setup().await;
1021        let cid = sqlite.create_conversation().await.unwrap();
1022        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
1023
1024        let point_id1 = uuid::Uuid::new_v4().to_string();
1025        zeph_db::query(sql!(
1026            "INSERT INTO embeddings_metadata \
1027             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1028             VALUES (?, ?, ?, ?, ?)"
1029        ))
1030        .bind(msg_id)
1031        .bind(0_i64)
1032        .bind(&point_id1)
1033        .bind(768_i64)
1034        .bind("qwen3-embedding")
1035        .execute(&pool)
1036        .await
1037        .unwrap();
1038
1039        // Same (message_id, chunk_index, model) — must fail.
1040        let point_id2 = uuid::Uuid::new_v4().to_string();
1041        let result = zeph_db::query(sql!(
1042            "INSERT INTO embeddings_metadata \
1043             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1044             VALUES (?, ?, ?, ?, ?)"
1045        ))
1046        .bind(msg_id)
1047        .bind(0_i64)
1048        .bind(&point_id2)
1049        .bind(768_i64)
1050        .bind("qwen3-embedding")
1051        .execute(&pool)
1052        .await;
1053        assert!(result.is_err());
1054
1055        // Different chunk_index — must succeed.
1056        let point_id3 = uuid::Uuid::new_v4().to_string();
1057        zeph_db::query(sql!(
1058            "INSERT INTO embeddings_metadata \
1059             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1060             VALUES (?, ?, ?, ?, ?)"
1061        ))
1062        .bind(msg_id)
1063        .bind(1_i64)
1064        .bind(&point_id3)
1065        .bind(768_i64)
1066        .bind("qwen3-embedding")
1067        .execute(&pool)
1068        .await
1069        .unwrap();
1070    }
1071
1072    #[tokio::test]
1073    async fn get_vectors_for_messages_returns_correct_vectors() {
1074        let (store, sqlite) = setup_with_store().await;
1075        let cid = sqlite.create_conversation().await.unwrap();
1076        let msg1 = sqlite.save_message(cid, "user", "hello").await.unwrap();
1077        let msg2 = sqlite.save_message(cid, "user", "world").await.unwrap();
1078
1079        store
1080            .store(
1081                msg1,
1082                cid,
1083                "user",
1084                vec![1.0, 0.0, 0.0, 0.0],
1085                MessageKind::Regular,
1086                "m",
1087                0,
1088            )
1089            .await
1090            .unwrap();
1091        store
1092            .store(
1093                msg2,
1094                cid,
1095                "user",
1096                vec![0.0, 1.0, 0.0, 0.0],
1097                MessageKind::Regular,
1098                "m",
1099                0,
1100            )
1101            .await
1102            .unwrap();
1103
1104        let result = store.get_vectors_for_messages(&[msg1, msg2]).await.unwrap();
1105        assert_eq!(result.len(), 2);
1106        let v1 = result.get(&msg1).unwrap();
1107        let v2 = result.get(&msg2).unwrap();
1108        assert!((v1[0] - 1.0).abs() < f32::EPSILON);
1109        assert!((v2[1] - 1.0).abs() < f32::EPSILON);
1110    }
1111
1112    #[tokio::test]
1113    async fn get_vectors_for_messages_missing_id_is_dropped() {
1114        let (store, sqlite) = setup_with_store().await;
1115        let cid = sqlite.create_conversation().await.unwrap();
1116        let msg1 = sqlite.save_message(cid, "user", "present").await.unwrap();
1117        let msg_absent = MessageId(99_999);
1118
1119        store
1120            .store(
1121                msg1,
1122                cid,
1123                "user",
1124                vec![1.0, 0.0, 0.0, 0.0],
1125                MessageKind::Regular,
1126                "m",
1127                0,
1128            )
1129            .await
1130            .unwrap();
1131
1132        let result = store
1133            .get_vectors_for_messages(&[msg1, msg_absent])
1134            .await
1135            .unwrap();
1136        assert_eq!(result.len(), 1);
1137        assert!(result.contains_key(&msg1));
1138        assert!(!result.contains_key(&msg_absent));
1139    }
1140
1141    #[tokio::test]
1142    async fn get_vectors_for_messages_empty_input() {
1143        let (store, _sqlite) = setup_with_store().await;
1144        let result = store.get_vectors_for_messages(&[]).await.unwrap();
1145        assert!(result.is_empty());
1146    }
1147
1148    #[tokio::test]
1149    async fn get_vectors_for_messages_chunk_index_0_only() {
1150        // Store chunk_index=0 and chunk_index=1; only chunk_index=0 should be returned.
1151        let (store, sqlite) = setup_with_store().await;
1152        let cid = sqlite.create_conversation().await.unwrap();
1153        let msg = sqlite.save_message(cid, "user", "chunked").await.unwrap();
1154
1155        store
1156            .store(
1157                msg,
1158                cid,
1159                "user",
1160                vec![1.0, 0.0, 0.0, 0.0],
1161                MessageKind::Regular,
1162                "m",
1163                0,
1164            )
1165            .await
1166            .unwrap();
1167        store
1168            .store(
1169                msg,
1170                cid,
1171                "user",
1172                vec![0.0, 0.0, 1.0, 0.0],
1173                MessageKind::Regular,
1174                "m",
1175                1,
1176            )
1177            .await
1178            .unwrap();
1179
1180        let result = store.get_vectors_for_messages(&[msg]).await.unwrap();
1181        assert_eq!(result.len(), 1);
1182        // Must be the chunk_index=0 vector
1183        let v = result.get(&msg).unwrap();
1184        assert!(
1185            (v[0] - 1.0).abs() < f32::EPSILON,
1186            "expected chunk_index=0 vector"
1187        );
1188    }
1189
1190    /// `delete_by_message_ids` resolves `message_id → qdrant_point_id` via
1191    /// `embeddings_metadata` and deletes the matching vectors.
1192    ///
1193    /// Verifies: (a) the correct point id is targeted, (b) `embeddings_metadata`
1194    /// rows are NOT removed (CASCADE handles that on hard-delete later), and (c) the
1195    /// method returns the number of point IDs found.
1196    #[tokio::test]
1197    async fn embedding_store_delete_by_message_ids_resolves_via_metadata() {
1198        let (store, sqlite) = setup_with_store().await;
1199        let cid = sqlite.create_conversation().await.unwrap();
1200        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
1201
1202        // Store a vector so embeddings_metadata gets a row.
1203        store
1204            .store(
1205                msg_id,
1206                cid,
1207                "user",
1208                vec![1.0, 0.0, 0.0, 0.0],
1209                MessageKind::Regular,
1210                "test-model",
1211                0,
1212            )
1213            .await
1214            .unwrap();
1215
1216        // Confirm the metadata row exists before deletion.
1217        assert!(store.has_embedding(msg_id).await.unwrap());
1218
1219        // Delete by message id — must succeed and return 1 (one point id resolved).
1220        let deleted = store.delete_by_message_ids(&[msg_id]).await.unwrap();
1221        assert_eq!(deleted, 1, "one point id should have been targeted");
1222
1223        // embeddings_metadata rows must still be present (CASCADE removes them later).
1224        let pool = sqlite.pool().clone();
1225        let row: (i64,) = zeph_db::query_as(sql!(
1226            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
1227        ))
1228        .bind(msg_id)
1229        .fetch_one(&pool)
1230        .await
1231        .unwrap();
1232        assert_eq!(
1233            row.0, 1,
1234            "embeddings_metadata row must survive delete_by_message_ids"
1235        );
1236    }
1237
1238    /// `delete_by_message_ids` is a no-op when the slice is empty.
1239    #[tokio::test]
1240    async fn embedding_store_delete_by_message_ids_empty_slice_is_noop() {
1241        let (store, _sqlite) = setup_with_store().await;
1242        let deleted = store.delete_by_message_ids(&[]).await.unwrap();
1243        assert_eq!(deleted, 0);
1244    }
1245}