Skip to main content

zeph_memory/
embedding_store.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Qdrant-backed embedding store for message vector search.
5//!
6//! [`EmbeddingStore`] owns a [`VectorStore`] implementation (Qdrant in production,
7//! [`crate::db_vector_store::DbVectorStore`] in tests) and exposes typed `embed` /
8//! `search` / `delete` operations used by [`crate::semantic::SemanticMemory`].
9//!
10//! Message vectors are stored in the `zeph_conversations` Qdrant collection with a
11//! payload that includes `message_id`, `conversation_id`, `role`, and `category`.
12
13pub use qdrant_client::qdrant::Filter;
14use zeph_db::DbPool;
15#[allow(unused_imports)]
16use zeph_db::sql;
17
18use crate::db_vector_store::DbVectorStore;
19use crate::error::MemoryError;
20use crate::qdrant_ops::QdrantOps;
21use crate::types::{ConversationId, MessageId};
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter, VectorPoint, VectorStore};
23
/// Category of an embedded message: an ordinary message or a generated summary.
///
/// The store methods in this module record the kind in the vector payload as the
/// boolean `is_summary` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum MessageKind {
    /// An ordinary conversation message.
    Regular,
    /// A compression summary produced by the summarization subsystem.
    Summary,
}

impl MessageKind {
    /// Returns `true` only for [`MessageKind::Summary`].
    #[must_use]
    pub fn is_summary(self) -> bool {
        self == Self::Summary
    }
}
43
/// Name of the collection used by every `EmbeddingStore` constructor in this module
/// for conversation message embeddings.
const COLLECTION_NAME: &str = "zeph_conversations";
45
46/// Ensure a Qdrant collection exists with cosine distance vectors.
47///
48/// Idempotent: no-op if the collection already exists.
49///
50/// # Errors
51///
52/// Returns an error if Qdrant cannot be reached or collection creation fails.
53pub async fn ensure_qdrant_collection(
54    ops: &QdrantOps,
55    collection: &str,
56    vector_size: u64,
57) -> Result<(), Box<qdrant_client::QdrantError>> {
58    ops.ensure_collection(collection, vector_size).await
59}
60
/// Typed wrapper over a [`VectorStore`] backend for conversation message embeddings.
///
/// Constructed via [`EmbeddingStore::new`] (Qdrant URL + optional API key),
/// [`EmbeddingStore::new_sqlite`] (database-backed vector storage), or
/// [`EmbeddingStore::with_store`] (custom backend for testing).
pub struct EmbeddingStore {
    /// Vector backend that performs upserts, searches, lookups, and deletes.
    ops: Box<dyn VectorStore>,
    /// Collection targeted by the message-level operations (`store*`, `search`,
    /// `get_vectors_for_messages`, `delete_by_message_ids`).
    collection: String,
    /// Database pool used for the `embeddings_metadata` and `graph_entities` queries.
    pool: DbPool,
}
70
71impl std::fmt::Debug for EmbeddingStore {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("EmbeddingStore")
74            .field("collection", &self.collection)
75            .finish_non_exhaustive()
76    }
77}
78
/// Optional filters applied to a vector similarity search.
///
/// Every field that is `Some` becomes one "must match" condition in
/// [`EmbeddingStore::search`]; if all fields are `None` the search is unfiltered.
#[derive(Debug)]
pub struct SearchFilter {
    /// Restrict results to a single conversation. `None` searches across all conversations.
    pub conversation_id: Option<ConversationId>,
    /// Restrict by message role (`"user"` / `"assistant"`). `None` returns all roles.
    pub role: Option<String>,
    /// Restrict by category payload field (category-aware memory, #2428).
    /// When `Some`, the search is restricted to vectors with a matching `category` payload.
    pub category: Option<String>,
}
90
/// A single result returned by [`EmbeddingStore::search`].
///
/// `search` emits at most one `SearchResult` per message, even when several chunks
/// of that message matched.
#[derive(Debug)]
pub struct SearchResult {
    /// Database row ID of the matching message.
    pub message_id: MessageId,
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// Highest similarity score among the message's matching chunks, as reported by
    /// the vector store.
    // NOTE(review): the exact range depends on the backend's distance metric — confirm
    // before relying on a `[0, 1]` bound.
    pub score: f32,
}
101
102impl EmbeddingStore {
103    /// Create a new `EmbeddingStore` connected to the given Qdrant URL with optional API key.
104    ///
105    /// `api_key` is forwarded to [`QdrantOps::new`]. The `pool` is used for `SQLite` metadata
106    /// operations on the `embeddings_metadata` table (which must already exist via sqlx
107    /// migrations).
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the Qdrant client cannot be created.
112    pub fn new(url: &str, api_key: Option<&str>, pool: DbPool) -> Result<Self, MemoryError> {
113        let ops = QdrantOps::new(url, api_key).map_err(MemoryError::Qdrant)?;
114
115        Ok(Self {
116            ops: Box::new(ops),
117            collection: COLLECTION_NAME.into(),
118            pool,
119        })
120    }
121
122    /// Create a new `EmbeddingStore` backed by `SQLite` for vector storage.
123    ///
124    /// Uses the same pool for both vector data and metadata. No external Qdrant required.
125    #[must_use]
126    pub fn new_sqlite(pool: DbPool) -> Self {
127        let ops = DbVectorStore::new(pool.clone());
128        Self {
129            ops: Box::new(ops),
130            collection: COLLECTION_NAME.into(),
131            pool,
132        }
133    }
134
135    /// Create an `EmbeddingStore` backed by an arbitrary [`VectorStore`] implementation.
136    ///
137    /// Intended for testing: inject a pre-configured or mock store without requiring
138    /// an external Qdrant instance.
139    #[must_use]
140    pub fn with_store(store: Box<dyn VectorStore>, pool: DbPool) -> Self {
141        Self {
142            ops: store,
143            collection: COLLECTION_NAME.into(),
144            pool,
145        }
146    }
147
148    /// Return `true` if the backing store is reachable and healthy.
149    pub async fn health_check(&self) -> bool {
150        self.ops.health_check().await.unwrap_or(false)
151    }
152
153    /// Ensure the collection exists in Qdrant with the given vector size.
154    ///
155    /// Idempotent: no-op if the collection already exists.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if Qdrant cannot be reached or collection creation fails.
160    pub async fn ensure_collection(&self, vector_size: u64) -> Result<(), MemoryError> {
161        self.ops
162            .ensure_collection(&self.collection, vector_size)
163            .await?;
164        // Create keyword indexes for the fields used in filtered recall so Qdrant can satisfy
165        // filter conditions in O(log n) instead of scanning all payload documents.
166        self.ops
167            .create_keyword_indexes(&self.collection, &["category", "conversation_id", "role"])
168            .await?;
169        Ok(())
170    }
171
172    /// Store a vector in Qdrant with additional tool execution metadata as payload fields.
173    ///
174    /// Metadata fields (`tool_name`, `exit_code`, `timestamp`) are stored as Qdrant payload
175    /// alongside the standard fields. This allows filtering and scoring by tool context
176    /// without corrupting the embedding vector with text prefixes.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
181    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
182    pub async fn store_with_tool_context(
183        &self,
184        message_id: MessageId,
185        conversation_id: ConversationId,
186        role: &str,
187        vector: Vec<f32>,
188        kind: MessageKind,
189        model: &str,
190        chunk_index: u32,
191        tool_name: &str,
192        exit_code: Option<i32>,
193        timestamp: Option<&str>,
194    ) -> Result<String, MemoryError> {
195        let point_id = uuid::Uuid::new_v4().to_string();
196        let dimensions = i64::try_from(vector.len())?;
197
198        let mut payload = std::collections::HashMap::from([
199            ("message_id".to_owned(), serde_json::json!(message_id.0)),
200            (
201                "conversation_id".to_owned(),
202                serde_json::json!(conversation_id.0),
203            ),
204            ("role".to_owned(), serde_json::json!(role)),
205            (
206                "is_summary".to_owned(),
207                serde_json::json!(kind.is_summary()),
208            ),
209            ("tool_name".to_owned(), serde_json::json!(tool_name)),
210        ]);
211        if let Some(code) = exit_code {
212            payload.insert("exit_code".to_owned(), serde_json::json!(code));
213        }
214        if let Some(ts) = timestamp {
215            payload.insert("timestamp".to_owned(), serde_json::json!(ts));
216        }
217
218        let point = VectorPoint {
219            id: point_id.clone(),
220            vector,
221            payload,
222        };
223
224        self.ops.upsert(&self.collection, vec![point]).await?;
225
226        let chunk_index_i64 = i64::from(chunk_index);
227        zeph_db::query(sql!(
228            "INSERT INTO embeddings_metadata \
229             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
230             VALUES (?, ?, ?, ?, ?) \
231             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
232             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
233        ))
234        .bind(message_id)
235        .bind(chunk_index_i64)
236        .bind(&point_id)
237        .bind(dimensions)
238        .bind(model)
239        .execute(&self.pool)
240        .await?;
241
242        Ok(point_id)
243    }
244
245    /// Store a vector in Qdrant and persist metadata to `SQLite`.
246    ///
247    /// `chunk_index` is 0 for single-vector messages and increases for each chunk
248    /// when a long message is split into multiple embeddings.
249    ///
250    /// Returns the UUID of the newly created Qdrant point.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
255    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
256    pub async fn store(
257        &self,
258        message_id: MessageId,
259        conversation_id: ConversationId,
260        role: &str,
261        vector: Vec<f32>,
262        kind: MessageKind,
263        model: &str,
264        chunk_index: u32,
265    ) -> Result<String, MemoryError> {
266        let point_id = uuid::Uuid::new_v4().to_string();
267        let dimensions = i64::try_from(vector.len())?;
268
269        let payload = std::collections::HashMap::from([
270            ("message_id".to_owned(), serde_json::json!(message_id.0)),
271            (
272                "conversation_id".to_owned(),
273                serde_json::json!(conversation_id.0),
274            ),
275            ("role".to_owned(), serde_json::json!(role)),
276            (
277                "is_summary".to_owned(),
278                serde_json::json!(kind.is_summary()),
279            ),
280        ]);
281
282        let point = VectorPoint {
283            id: point_id.clone(),
284            vector,
285            payload,
286        };
287
288        self.ops.upsert(&self.collection, vec![point]).await?;
289
290        let chunk_index_i64 = i64::from(chunk_index);
291        zeph_db::query(sql!(
292            "INSERT INTO embeddings_metadata \
293             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
294             VALUES (?, ?, ?, ?, ?) \
295             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
296             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
297        ))
298        .bind(message_id)
299        .bind(chunk_index_i64)
300        .bind(&point_id)
301        .bind(dimensions)
302        .bind(model)
303        .execute(&self.pool)
304        .await?;
305
306        Ok(point_id)
307    }
308
309    /// Store a vector with an optional category tag in the Qdrant payload.
310    ///
311    /// Identical to [`Self::store`] but adds a `category` field to the payload when provided.
312    /// Used by category-aware memory (#2428) to enable category-filtered recall.
313    ///
314    /// Note: when `category` is `None` no `category` field is written to the Qdrant payload.
315    /// Memories stored before category-aware recall was enabled therefore won't match a
316    /// category filter — this is intentional (no silent false-positives), but a backfill
317    /// pass is needed if retrospective categorization is desired.
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if the Qdrant upsert or `SQLite` insert fails.
322    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
323    pub async fn store_with_category(
324        &self,
325        message_id: MessageId,
326        conversation_id: ConversationId,
327        role: &str,
328        vector: Vec<f32>,
329        kind: MessageKind,
330        model: &str,
331        chunk_index: u32,
332        category: Option<&str>,
333    ) -> Result<String, MemoryError> {
334        let point_id = uuid::Uuid::new_v4().to_string();
335        let dimensions = i64::try_from(vector.len())?;
336
337        let mut payload = std::collections::HashMap::from([
338            ("message_id".to_owned(), serde_json::json!(message_id.0)),
339            (
340                "conversation_id".to_owned(),
341                serde_json::json!(conversation_id.0),
342            ),
343            ("role".to_owned(), serde_json::json!(role)),
344            (
345                "is_summary".to_owned(),
346                serde_json::json!(kind.is_summary()),
347            ),
348        ]);
349        if let Some(cat) = category {
350            payload.insert("category".to_owned(), serde_json::json!(cat));
351        }
352
353        let point = VectorPoint {
354            id: point_id.clone(),
355            vector,
356            payload,
357        };
358
359        self.ops.upsert(&self.collection, vec![point]).await?;
360
361        let chunk_index_i64 = i64::from(chunk_index);
362        zeph_db::query(sql!(
363            "INSERT INTO embeddings_metadata \
364             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
365             VALUES (?, ?, ?, ?, ?) \
366             ON CONFLICT(message_id, chunk_index, model) DO UPDATE SET \
367             qdrant_point_id = excluded.qdrant_point_id, dimensions = excluded.dimensions"
368        ))
369        .bind(message_id)
370        .bind(chunk_index_i64)
371        .bind(&point_id)
372        .bind(dimensions)
373        .bind(model)
374        .execute(&self.pool)
375        .await?;
376
377        Ok(point_id)
378    }
379
380    /// Search for similar vectors in Qdrant, returning up to `limit` results.
381    ///
382    /// # Errors
383    ///
384    /// Returns an error if the Qdrant search fails.
385    pub async fn search(
386        &self,
387        query_vector: &[f32],
388        limit: usize,
389        filter: Option<SearchFilter>,
390    ) -> Result<Vec<SearchResult>, MemoryError> {
391        let limit_u64 = u64::try_from(limit)?;
392
393        let vector_filter = filter.as_ref().and_then(|f| {
394            let mut must = Vec::new();
395            if let Some(cid) = f.conversation_id {
396                must.push(FieldCondition {
397                    field: "conversation_id".into(),
398                    value: FieldValue::Integer(cid.0),
399                });
400            }
401            if let Some(ref role) = f.role {
402                must.push(FieldCondition {
403                    field: "role".into(),
404                    value: FieldValue::Text(role.clone()),
405                });
406            }
407            if let Some(ref category) = f.category {
408                must.push(FieldCondition {
409                    field: "category".into(),
410                    value: FieldValue::Text(category.clone()),
411                });
412            }
413            if must.is_empty() {
414                None
415            } else {
416                Some(VectorFilter {
417                    must,
418                    must_not: vec![],
419                })
420            }
421        });
422
423        let results = self
424            .ops
425            .search(
426                &self.collection,
427                query_vector.to_vec(),
428                limit_u64,
429                vector_filter,
430            )
431            .await?;
432
433        // Deduplicate by message_id, keeping the chunk with the highest score.
434        // A single message may produce multiple Qdrant points (one per chunk).
435        let mut best: std::collections::HashMap<MessageId, SearchResult> =
436            std::collections::HashMap::new();
437        for point in results {
438            let Some(message_id) = point
439                .payload
440                .get("message_id")
441                .and_then(serde_json::Value::as_i64)
442            else {
443                continue;
444            };
445            let Some(conversation_id) = point
446                .payload
447                .get("conversation_id")
448                .and_then(serde_json::Value::as_i64)
449            else {
450                continue;
451            };
452            let message_id = MessageId(message_id);
453            let entry = best.entry(message_id).or_insert(SearchResult {
454                message_id,
455                conversation_id: ConversationId(conversation_id),
456                score: f32::NEG_INFINITY,
457            });
458            if point.score > entry.score {
459                entry.score = point.score;
460            }
461        }
462
463        let mut search_results: Vec<SearchResult> = best.into_values().collect();
464        search_results.sort_by(|a, b| {
465            b.score
466                .partial_cmp(&a.score)
467                .unwrap_or(std::cmp::Ordering::Equal)
468        });
469        search_results.truncate(limit);
470
471        Ok(search_results)
472    }
473
474    /// Check whether a named collection exists in the vector store.
475    ///
476    /// # Errors
477    ///
478    /// Returns an error if the store backend cannot be reached.
479    pub async fn collection_exists(&self, name: &str) -> Result<bool, MemoryError> {
480        self.ops.collection_exists(name).await.map_err(Into::into)
481    }
482
483    /// Ensure a named collection exists in Qdrant with the given vector size.
484    ///
485    /// # Errors
486    ///
487    /// Returns an error if Qdrant cannot be reached or collection creation fails.
488    pub async fn ensure_named_collection(
489        &self,
490        name: &str,
491        vector_size: u64,
492    ) -> Result<(), MemoryError> {
493        self.ops.ensure_collection(name, vector_size).await?;
494        Ok(())
495    }
496
497    /// Store a vector in a named Qdrant collection with arbitrary payload.
498    ///
499    /// Returns the UUID of the newly created point.
500    ///
501    /// # Errors
502    ///
503    /// Returns an error if the Qdrant upsert fails.
504    pub async fn store_to_collection(
505        &self,
506        collection: &str,
507        payload: serde_json::Value,
508        vector: Vec<f32>,
509    ) -> Result<String, MemoryError> {
510        let point_id = uuid::Uuid::new_v4().to_string();
511        let payload_map: std::collections::HashMap<String, serde_json::Value> =
512            serde_json::from_value(payload)?;
513        let point = VectorPoint {
514            id: point_id.clone(),
515            vector,
516            payload: payload_map,
517        };
518        self.ops.upsert(collection, vec![point]).await?;
519        Ok(point_id)
520    }
521
522    /// Upsert a vector into a named collection, reusing an existing point ID.
523    ///
524    /// Use this when updating an existing entity to avoid orphaned Qdrant points.
525    ///
526    /// # Errors
527    ///
528    /// Returns an error if the Qdrant upsert fails.
529    pub async fn upsert_to_collection(
530        &self,
531        collection: &str,
532        point_id: &str,
533        payload: serde_json::Value,
534        vector: Vec<f32>,
535    ) -> Result<(), MemoryError> {
536        let payload_map: std::collections::HashMap<String, serde_json::Value> =
537            serde_json::from_value(payload)?;
538        let point = VectorPoint {
539            id: point_id.to_owned(),
540            vector,
541            payload: payload_map,
542        };
543        self.ops.upsert(collection, vec![point]).await?;
544        Ok(())
545    }
546
547    /// Search a named Qdrant collection, returning scored points with payloads.
548    ///
549    /// # Errors
550    ///
551    /// Returns an error if the Qdrant search fails.
552    pub async fn search_collection(
553        &self,
554        collection: &str,
555        query_vector: &[f32],
556        limit: usize,
557        filter: Option<VectorFilter>,
558    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
559        let limit_u64 = u64::try_from(limit)?;
560        let results = self
561            .ops
562            .search(collection, query_vector.to_vec(), limit_u64, filter)
563            .await?;
564        Ok(results)
565    }
566
567    /// Enumerate `(point_id, entity_id)` pairs for all points in `collection` that carry
568    /// an `entity_id_str` payload field.
569    ///
570    /// `entity_id_str` is a string mirror of the i64 `entity_id` written alongside the numeric
571    /// field at embedding time. The scroll API only surfaces string-typed payload values, so a
572    /// parallel string field is necessary for enumeration. Points missing `entity_id_str`
573    /// (written before this field was added) are silently skipped — they will gain the field on
574    /// the next `merge_entity` or `store_entity_embedding` call.
575    ///
576    /// # Errors
577    ///
578    /// Returns an error if the underlying scroll operation fails.
579    pub async fn scroll_all_entity_ids(
580        &self,
581        collection: &str,
582    ) -> Result<Vec<(String, i64)>, MemoryError> {
583        let rows = self
584            .ops
585            .scroll_all_with_point_ids(collection, "entity_id_str")
586            .await?;
587        let mut out = Vec::with_capacity(rows.len());
588        for (point_id, fields) in rows {
589            let Some(s) = fields.get("entity_id_str") else {
590                continue;
591            };
592            if let Ok(id) = s.parse::<i64>() {
593                out.push((point_id, id));
594            } else {
595                tracing::debug!(point_id, value = %s, "entity_id_str unparseable, skipping");
596            }
597        }
598        Ok(out)
599    }
600
601    /// Delete a set of points from a named collection by their Qdrant point IDs.
602    ///
603    /// This is a thin wrapper over [`VectorStore::delete_by_ids`] for use by
604    /// the stale-embedding cleanup path in `community.rs`.
605    ///
606    /// # Errors
607    ///
608    /// Returns an error if the underlying delete operation fails.
609    pub async fn delete_from_collection(
610        &self,
611        collection: &str,
612        ids: Vec<String>,
613    ) -> Result<(), MemoryError> {
614        if ids.is_empty() {
615            return Ok(());
616        }
617        self.ops.delete_by_ids(collection, ids).await?;
618        Ok(())
619    }
620
621    /// Retrieve raw vectors for the given Qdrant point IDs from `collection`.
622    ///
623    /// Returns a map of `point_id → embedding`. Missing ids are silently dropped.
624    /// Returns an empty map when the backend does not support vector retrieval
625    /// (e.g. `DbVectorStore` / `InMemoryVectorStore` without an override).
626    ///
627    /// # Errors
628    ///
629    /// Returns an error if the underlying store returns a non-`Unsupported` error.
630    pub async fn get_vectors_from_collection(
631        &self,
632        collection: &str,
633        point_ids: &[String],
634    ) -> Result<std::collections::HashMap<String, Vec<f32>>, MemoryError> {
635        if point_ids.is_empty() {
636            return Ok(std::collections::HashMap::new());
637        }
638        match self.ops.get_points(collection, point_ids.to_vec()).await {
639            Ok(points) => Ok(points.into_iter().map(|p| (p.id, p.vector)).collect()),
640            Err(crate::VectorStoreError::Unsupported(_)) => Ok(std::collections::HashMap::new()),
641            Err(e) => Err(MemoryError::VectorStore(e)),
642        }
643    }
644
645    /// Fetch raw vectors for the given message IDs from the `SQLite` vector store.
646    ///
647    /// Returns an empty map when using Qdrant backend (vectors not locally stored).
648    ///
649    /// # Errors
650    ///
651    /// Returns an error if the `SQLite` query fails.
652    pub async fn get_vectors(
653        &self,
654        ids: &[MessageId],
655    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
656        if ids.is_empty() {
657            return Ok(std::collections::HashMap::new());
658        }
659
660        let placeholders = zeph_db::placeholder_list(1, ids.len());
661        let query = format!(
662            "SELECT em.message_id, vp.vector \
663             FROM embeddings_metadata em \
664             JOIN vector_points vp ON vp.id = em.qdrant_point_id \
665             WHERE em.message_id IN ({placeholders}) AND em.chunk_index = 0"
666        );
667        let mut q = zeph_db::query_as::<_, (MessageId, Vec<u8>)>(&query);
668        for &id in ids {
669            q = q.bind(id);
670        }
671
672        let rows = q.fetch_all(&self.pool).await?;
673
674        let map = rows
675            .into_iter()
676            .filter_map(|(msg_id, blob)| {
677                if blob.len() % 4 != 0 {
678                    return None;
679                }
680                let vec: Vec<f32> = blob
681                    .chunks_exact(4)
682                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
683                    .collect();
684                Some((msg_id, vec))
685            })
686            .collect();
687
688        Ok(map)
689    }
690
691    /// Fetch embeddings for the given message IDs from the configured vector store.
692    ///
693    /// Resolves `message_id → qdrant_point_id` via `embeddings_metadata` (filtering to
694    /// `chunk_index = 0` so each message yields at most one vector), then retrieves the
695    /// vectors from the underlying [`VectorStore`].
696    ///
697    /// Returns a map from [`MessageId`] to embedding vector. Messages without an
698    /// `embeddings_metadata` row, or whose vector cannot be retrieved, are silently dropped.
699    /// When the backend returns [`crate::VectorStoreError::Unsupported`], an empty map is
700    /// returned without error (matches [`Self::get_vectors_from_collection`] semantics).
701    ///
702    /// # Errors
703    ///
704    /// Returns an error if the `SQLite` metadata query or vector store retrieval fails.
705    pub async fn get_vectors_for_messages(
706        &self,
707        ids: &[MessageId],
708    ) -> Result<std::collections::HashMap<MessageId, Vec<f32>>, MemoryError> {
709        if ids.is_empty() {
710            return Ok(std::collections::HashMap::new());
711        }
712
713        let placeholders = zeph_db::placeholder_list(1, ids.len());
714        let query = format!(
715            "SELECT message_id, qdrant_point_id \
716             FROM embeddings_metadata \
717             WHERE message_id IN ({placeholders}) AND chunk_index = 0"
718        );
719        let mut q = zeph_db::query_as::<_, (MessageId, String)>(&query);
720        for &id in ids {
721            q = q.bind(id);
722        }
723        let rows: Vec<(MessageId, String)> = q.fetch_all(&self.pool).await?;
724
725        if rows.is_empty() {
726            return Ok(std::collections::HashMap::new());
727        }
728
729        // Build reverse map: point_id → message_id for result translation.
730        let mut point_to_msg: std::collections::HashMap<String, MessageId> =
731            std::collections::HashMap::with_capacity(rows.len());
732        let point_ids: Vec<String> = rows
733            .into_iter()
734            .map(|(msg_id, point_id)| {
735                point_to_msg.insert(point_id.clone(), msg_id);
736                point_id
737            })
738            .collect();
739
740        let points = match self.ops.get_points(&self.collection, point_ids).await {
741            Ok(pts) => pts,
742            Err(crate::VectorStoreError::Unsupported(_)) => {
743                return Ok(std::collections::HashMap::new());
744            }
745            Err(e) => return Err(MemoryError::VectorStore(e)),
746        };
747
748        let result = points
749            .into_iter()
750            .filter_map(|p| {
751                let msg_id = point_to_msg.get(&p.id).copied()?;
752                Some((msg_id, p.vector))
753            })
754            .collect();
755
756        Ok(result)
757    }
758
759    /// Delete all Qdrant vectors associated with the given message IDs.
760    ///
761    /// Resolves `message_id → qdrant_point_id` via the `embeddings_metadata` table,
762    /// then calls the underlying vector store's `delete_by_ids`. The
763    /// `embeddings_metadata` rows are **not** removed here — the `SQLite` CASCADE on
764    /// `messages` handles that when the rows are hard-deleted later.
765    ///
766    /// Returns the number of Qdrant point IDs targeted for deletion (may be less than
767    /// `ids.len()` when some messages have no embeddings).
768    ///
769    /// # Errors
770    ///
771    /// Returns [`MemoryError`] if the `SQLite` query or the vector store delete fails.
772    pub async fn delete_by_message_ids(&self, ids: &[MessageId]) -> Result<usize, MemoryError> {
773        if ids.is_empty() {
774            return Ok(0);
775        }
776
777        let placeholders = zeph_db::placeholder_list(1, ids.len());
778        let query = format!(
779            "SELECT qdrant_point_id FROM embeddings_metadata WHERE message_id IN ({placeholders})"
780        );
781        let mut q = zeph_db::query_as::<_, (String,)>(&query);
782        for &id in ids {
783            q = q.bind(id);
784        }
785        let rows: Vec<(String,)> = q.fetch_all(&self.pool).await?;
786
787        let point_ids: Vec<String> = rows.into_iter().map(|(id,)| id).collect();
788        let count = point_ids.len();
789
790        if !point_ids.is_empty() {
791            self.ops.delete_by_ids(&self.collection, point_ids).await?;
792        }
793
794        Ok(count)
795    }
796
797    /// Check whether an embedding already exists for the given message ID.
798    ///
799    /// # Errors
800    ///
801    /// Returns an error if the `SQLite` query fails.
802    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
803        let row: (i64,) = zeph_db::query_as(sql!(
804            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
805        ))
806        .bind(message_id)
807        .fetch_one(&self.pool)
808        .await?;
809
810        Ok(row.0 > 0)
811    }
812
813    /// Check whether a Qdrant embedding for `entity_name` is current by comparing the
814    /// Qdrant-side epoch against the epoch stored in `graph_entities`.
815    ///
816    /// Returns `true` if the Qdrant embedding is up-to-date or if the entity no longer
817    /// exists in `SQLite` (embedding should be cleaned up separately).
818    ///
819    /// # Errors
820    ///
821    /// Returns an error if the `SQLite` query fails.
822    pub async fn is_epoch_current(
823        &self,
824        entity_name: &str,
825        qdrant_epoch: u64,
826    ) -> Result<bool, MemoryError> {
827        let row: Option<(i64,)> = zeph_db::query_as(sql!(
828            "SELECT embedding_epoch FROM graph_entities WHERE name = ? LIMIT 1"
829        ))
830        .bind(entity_name)
831        .fetch_optional(&self.pool)
832        .await?;
833
834        match row {
835            None => Ok(true), // entity deleted; Qdrant point is orphaned, not stale per epoch
836            Some((db_epoch,)) => Ok(qdrant_epoch >= db_epoch.cast_unsigned()),
837        }
838    }
839}
840
841#[cfg(test)]
842mod tests {
843    use super::*;
844    use crate::in_memory_store::InMemoryVectorStore;
845    use crate::store::SqliteStore;
846
847    async fn setup() -> (SqliteStore, DbPool) {
848        let store = SqliteStore::new(":memory:").await.unwrap();
849        let pool = store.pool().clone();
850        (store, pool)
851    }
852
853    async fn setup_with_store() -> (EmbeddingStore, SqliteStore) {
854        let sqlite = SqliteStore::new(":memory:").await.unwrap();
855        let pool = sqlite.pool().clone();
856        let mem_store = Box::new(InMemoryVectorStore::new());
857        let embedding_store = EmbeddingStore::with_store(mem_store, pool);
858        // Create collection first
859        embedding_store.ensure_collection(4).await.unwrap();
860        (embedding_store, sqlite)
861    }
862
863    #[tokio::test]
864    async fn has_embedding_returns_false_when_none() {
865        let (_store, pool) = setup().await;
866
867        let row: (i64,) = zeph_db::query_as(sql!(
868            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
869        ))
870        .bind(999_i64)
871        .fetch_one(&pool)
872        .await
873        .unwrap();
874
875        assert_eq!(row.0, 0);
876    }
877
878    #[tokio::test]
879    async fn insert_and_query_embeddings_metadata() {
880        let (sqlite, pool) = setup().await;
881        let cid = sqlite.create_conversation().await.unwrap();
882        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
883
884        let point_id = uuid::Uuid::new_v4().to_string();
885        zeph_db::query(sql!(
886            "INSERT INTO embeddings_metadata \
887             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
888             VALUES (?, ?, ?, ?, ?)"
889        ))
890        .bind(msg_id)
891        .bind(0_i64)
892        .bind(&point_id)
893        .bind(768_i64)
894        .bind("qwen3-embedding")
895        .execute(&pool)
896        .await
897        .unwrap();
898
899        let row: (i64,) = zeph_db::query_as(sql!(
900            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
901        ))
902        .bind(msg_id)
903        .fetch_one(&pool)
904        .await
905        .unwrap();
906        assert_eq!(row.0, 1);
907    }
908
909    #[tokio::test]
910    async fn embedding_store_search_empty_returns_empty() {
911        let (store, _sqlite) = setup_with_store().await;
912        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 10, None).await.unwrap();
913        assert!(results.is_empty());
914    }
915
916    #[tokio::test]
917    async fn embedding_store_store_and_search() {
918        let (store, sqlite) = setup_with_store().await;
919        let cid = sqlite.create_conversation().await.unwrap();
920        let msg_id = sqlite
921            .save_message(cid, "user", "test message")
922            .await
923            .unwrap();
924
925        store
926            .store(
927                msg_id,
928                cid,
929                "user",
930                vec![1.0, 0.0, 0.0, 0.0],
931                MessageKind::Regular,
932                "test-model",
933                0,
934            )
935            .await
936            .unwrap();
937
938        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 5, None).await.unwrap();
939        assert_eq!(results.len(), 1);
940        assert_eq!(results[0].message_id, msg_id);
941        assert_eq!(results[0].conversation_id, cid);
942        assert!((results[0].score - 1.0).abs() < 0.001);
943    }
944
945    #[tokio::test]
946    async fn embedding_store_has_embedding_false_for_unknown() {
947        let (store, sqlite) = setup_with_store().await;
948        let cid = sqlite.create_conversation().await.unwrap();
949        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
950        assert!(!store.has_embedding(msg_id).await.unwrap());
951    }
952
953    #[tokio::test]
954    async fn embedding_store_has_embedding_true_after_store() {
955        let (store, sqlite) = setup_with_store().await;
956        let cid = sqlite.create_conversation().await.unwrap();
957        let msg_id = sqlite.save_message(cid, "user", "hello").await.unwrap();
958
959        store
960            .store(
961                msg_id,
962                cid,
963                "user",
964                vec![0.0, 1.0, 0.0, 0.0],
965                MessageKind::Regular,
966                "test-model",
967                0,
968            )
969            .await
970            .unwrap();
971
972        assert!(store.has_embedding(msg_id).await.unwrap());
973    }
974
975    #[tokio::test]
976    async fn embedding_store_search_with_conversation_filter() {
977        let (store, sqlite) = setup_with_store().await;
978        let cid1 = sqlite.create_conversation().await.unwrap();
979        let cid2 = sqlite.create_conversation().await.unwrap();
980        let msg1 = sqlite.save_message(cid1, "user", "msg1").await.unwrap();
981        let msg2 = sqlite.save_message(cid2, "user", "msg2").await.unwrap();
982
983        store
984            .store(
985                msg1,
986                cid1,
987                "user",
988                vec![1.0, 0.0, 0.0, 0.0],
989                MessageKind::Regular,
990                "m",
991                0,
992            )
993            .await
994            .unwrap();
995        store
996            .store(
997                msg2,
998                cid2,
999                "user",
1000                vec![1.0, 0.0, 0.0, 0.0],
1001                MessageKind::Regular,
1002                "m",
1003                0,
1004            )
1005            .await
1006            .unwrap();
1007
1008        let results = store
1009            .search(
1010                &[1.0, 0.0, 0.0, 0.0],
1011                10,
1012                Some(SearchFilter {
1013                    conversation_id: Some(cid1),
1014                    role: None,
1015                    category: None,
1016                }),
1017            )
1018            .await
1019            .unwrap();
1020        assert_eq!(results.len(), 1);
1021        assert_eq!(results[0].conversation_id, cid1);
1022    }
1023
1024    #[tokio::test]
1025    async fn unique_constraint_on_message_chunk_and_model() {
1026        let (sqlite, pool) = setup().await;
1027        let cid = sqlite.create_conversation().await.unwrap();
1028        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
1029
1030        let point_id1 = uuid::Uuid::new_v4().to_string();
1031        zeph_db::query(sql!(
1032            "INSERT INTO embeddings_metadata \
1033             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1034             VALUES (?, ?, ?, ?, ?)"
1035        ))
1036        .bind(msg_id)
1037        .bind(0_i64)
1038        .bind(&point_id1)
1039        .bind(768_i64)
1040        .bind("qwen3-embedding")
1041        .execute(&pool)
1042        .await
1043        .unwrap();
1044
1045        // Same (message_id, chunk_index, model) — must fail.
1046        let point_id2 = uuid::Uuid::new_v4().to_string();
1047        let result = zeph_db::query(sql!(
1048            "INSERT INTO embeddings_metadata \
1049             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1050             VALUES (?, ?, ?, ?, ?)"
1051        ))
1052        .bind(msg_id)
1053        .bind(0_i64)
1054        .bind(&point_id2)
1055        .bind(768_i64)
1056        .bind("qwen3-embedding")
1057        .execute(&pool)
1058        .await;
1059        assert!(result.is_err());
1060
1061        // Different chunk_index — must succeed.
1062        let point_id3 = uuid::Uuid::new_v4().to_string();
1063        zeph_db::query(sql!(
1064            "INSERT INTO embeddings_metadata \
1065             (message_id, chunk_index, qdrant_point_id, dimensions, model) \
1066             VALUES (?, ?, ?, ?, ?)"
1067        ))
1068        .bind(msg_id)
1069        .bind(1_i64)
1070        .bind(&point_id3)
1071        .bind(768_i64)
1072        .bind("qwen3-embedding")
1073        .execute(&pool)
1074        .await
1075        .unwrap();
1076    }
1077
1078    #[tokio::test]
1079    async fn get_vectors_for_messages_returns_correct_vectors() {
1080        let (store, sqlite) = setup_with_store().await;
1081        let cid = sqlite.create_conversation().await.unwrap();
1082        let msg1 = sqlite.save_message(cid, "user", "hello").await.unwrap();
1083        let msg2 = sqlite.save_message(cid, "user", "world").await.unwrap();
1084
1085        store
1086            .store(
1087                msg1,
1088                cid,
1089                "user",
1090                vec![1.0, 0.0, 0.0, 0.0],
1091                MessageKind::Regular,
1092                "m",
1093                0,
1094            )
1095            .await
1096            .unwrap();
1097        store
1098            .store(
1099                msg2,
1100                cid,
1101                "user",
1102                vec![0.0, 1.0, 0.0, 0.0],
1103                MessageKind::Regular,
1104                "m",
1105                0,
1106            )
1107            .await
1108            .unwrap();
1109
1110        let result = store.get_vectors_for_messages(&[msg1, msg2]).await.unwrap();
1111        assert_eq!(result.len(), 2);
1112        let v1 = result.get(&msg1).unwrap();
1113        let v2 = result.get(&msg2).unwrap();
1114        assert!((v1[0] - 1.0).abs() < f32::EPSILON);
1115        assert!((v2[1] - 1.0).abs() < f32::EPSILON);
1116    }
1117
1118    #[tokio::test]
1119    async fn get_vectors_for_messages_missing_id_is_dropped() {
1120        let (store, sqlite) = setup_with_store().await;
1121        let cid = sqlite.create_conversation().await.unwrap();
1122        let msg1 = sqlite.save_message(cid, "user", "present").await.unwrap();
1123        let msg_absent = MessageId(99_999);
1124
1125        store
1126            .store(
1127                msg1,
1128                cid,
1129                "user",
1130                vec![1.0, 0.0, 0.0, 0.0],
1131                MessageKind::Regular,
1132                "m",
1133                0,
1134            )
1135            .await
1136            .unwrap();
1137
1138        let result = store
1139            .get_vectors_for_messages(&[msg1, msg_absent])
1140            .await
1141            .unwrap();
1142        assert_eq!(result.len(), 1);
1143        assert!(result.contains_key(&msg1));
1144        assert!(!result.contains_key(&msg_absent));
1145    }
1146
1147    #[tokio::test]
1148    async fn get_vectors_for_messages_empty_input() {
1149        let (store, _sqlite) = setup_with_store().await;
1150        let result = store.get_vectors_for_messages(&[]).await.unwrap();
1151        assert!(result.is_empty());
1152    }
1153
1154    #[tokio::test]
1155    async fn get_vectors_for_messages_chunk_index_0_only() {
1156        // Store chunk_index=0 and chunk_index=1; only chunk_index=0 should be returned.
1157        let (store, sqlite) = setup_with_store().await;
1158        let cid = sqlite.create_conversation().await.unwrap();
1159        let msg = sqlite.save_message(cid, "user", "chunked").await.unwrap();
1160
1161        store
1162            .store(
1163                msg,
1164                cid,
1165                "user",
1166                vec![1.0, 0.0, 0.0, 0.0],
1167                MessageKind::Regular,
1168                "m",
1169                0,
1170            )
1171            .await
1172            .unwrap();
1173        store
1174            .store(
1175                msg,
1176                cid,
1177                "user",
1178                vec![0.0, 0.0, 1.0, 0.0],
1179                MessageKind::Regular,
1180                "m",
1181                1,
1182            )
1183            .await
1184            .unwrap();
1185
1186        let result = store.get_vectors_for_messages(&[msg]).await.unwrap();
1187        assert_eq!(result.len(), 1);
1188        // Must be the chunk_index=0 vector
1189        let v = result.get(&msg).unwrap();
1190        assert!(
1191            (v[0] - 1.0).abs() < f32::EPSILON,
1192            "expected chunk_index=0 vector"
1193        );
1194    }
1195
1196    /// `delete_by_message_ids` resolves `message_id → qdrant_point_id` via
1197    /// `embeddings_metadata` and deletes the matching vectors.
1198    ///
1199    /// Verifies: (a) the correct point id is targeted, (b) `embeddings_metadata`
1200    /// rows are NOT removed (CASCADE handles that on hard-delete later), and (c) the
1201    /// method returns the number of point IDs found.
1202    #[tokio::test]
1203    async fn embedding_store_delete_by_message_ids_resolves_via_metadata() {
1204        let (store, sqlite) = setup_with_store().await;
1205        let cid = sqlite.create_conversation().await.unwrap();
1206        let msg_id = sqlite.save_message(cid, "user", "test").await.unwrap();
1207
1208        // Store a vector so embeddings_metadata gets a row.
1209        store
1210            .store(
1211                msg_id,
1212                cid,
1213                "user",
1214                vec![1.0, 0.0, 0.0, 0.0],
1215                MessageKind::Regular,
1216                "test-model",
1217                0,
1218            )
1219            .await
1220            .unwrap();
1221
1222        // Confirm the metadata row exists before deletion.
1223        assert!(store.has_embedding(msg_id).await.unwrap());
1224
1225        // Delete by message id — must succeed and return 1 (one point id resolved).
1226        let deleted = store.delete_by_message_ids(&[msg_id]).await.unwrap();
1227        assert_eq!(deleted, 1, "one point id should have been targeted");
1228
1229        // embeddings_metadata rows must still be present (CASCADE removes them later).
1230        let pool = sqlite.pool().clone();
1231        let row: (i64,) = zeph_db::query_as(sql!(
1232            "SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?"
1233        ))
1234        .bind(msg_id)
1235        .fetch_one(&pool)
1236        .await
1237        .unwrap();
1238        assert_eq!(
1239            row.0, 1,
1240            "embeddings_metadata row must survive delete_by_message_ids"
1241        );
1242    }
1243
1244    /// `delete_by_message_ids` is a no-op when the slice is empty.
1245    #[tokio::test]
1246    async fn embedding_store_delete_by_message_ids_empty_slice_is_noop() {
1247        let (store, _sqlite) = setup_with_store().await;
1248        let deleted = store.delete_by_message_ids(&[]).await.unwrap();
1249        assert_eq!(deleted, 0);
1250    }
1251}