//! vex_persist/vector_store.rs — pluggable vector-embedding storage backends:
//! in-memory (testing), SQLite (persistent), and optional pgvector/Postgres.

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use thiserror::Error;
7
/// Errors produced by vector-store operations across all backends.
#[derive(Error, Debug)]
pub enum VectorError {
    /// A vector's length did not match the store's configured dimension
    /// (first field: expected, second: actual).
    #[error("Dimension mismatch: expected {0}, got {1}")]
    DimensionMismatch(usize, usize),
    /// Metadata could not be (de)serialized to/from JSON.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// The underlying database driver reported a failure.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// The in-memory store reached its hard entry cap.
    #[error("Storage full: capacity exceeded")]
    StorageFull,
}
19
/// A stored embedding: an identifier, the raw f32 vector, and free-form
/// string metadata used for search-time filtering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEmbedding {
    /// Caller-supplied identifier (unique per tenant in the SQL backends).
    pub id: String,
    /// Embedding components; length must equal the store's dimension.
    pub vector: Vec<f32>,
    /// Arbitrary key/value metadata; a search filter matches when every
    /// filter pair is present here.
    pub metadata: HashMap<String, String>,
}
26
/// Generic trait for vector storage
#[async_trait]
pub trait VectorStoreBackend: Send + Sync + std::fmt::Debug {
    /// Stores `vector` with its `metadata` under `(id, tenant_id)`.
    ///
    /// Returns `VectorError::DimensionMismatch` when `vector.len()` differs
    /// from the store's configured dimension. All data is scoped by tenant.
    async fn add(
        &self,
        id: String,
        tenant_id: String,
        vector: Vec<f32>,
        metadata: HashMap<String, String>,
    ) -> Result<(), VectorError>;

    /// Returns up to `k` of `tenant_id`'s embeddings scored against `query`
    /// (cosine similarity, highest first). When `filters` is `Some`, only
    /// embeddings whose metadata contains every key/value pair are returned.
    async fn search(
        &self,
        tenant_id: &str,
        query: &[f32],
        k: usize,
        filters: Option<HashMap<String, String>>,
    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError>;
}
46
/// In-memory vector store implementation (for testing and small contexts)
#[derive(Debug, Clone)]
pub struct MemoryVectorStore {
    /// Required length of every stored and query vector.
    dimension: usize,
    /// Flat entry list; `Arc` so clones of the store share one dataset.
    embeddings: Arc<RwLock<Vec<(String, String, VectorEmbedding)>>>, // (id, tenant_id, embedding)
}
53
54impl MemoryVectorStore {
55    pub fn new(dimension: usize) -> Self {
56        Self {
57            dimension,
58            embeddings: Arc::new(RwLock::new(Vec::new())),
59        }
60    }
61}
62
63#[async_trait]
64impl VectorStoreBackend for MemoryVectorStore {
65    async fn add(
66        &self,
67        id: String,
68        tenant_id: String,
69        vector: Vec<f32>,
70        metadata: HashMap<String, String>,
71    ) -> Result<(), VectorError> {
72        if vector.len() != self.dimension {
73            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
74        }
75
76        let mut data = self.embeddings.write().unwrap();
77
78        // Limit capacity to prevent memory DoS (Fix #12)
79        if data.len() >= 100_000 {
80            return Err(VectorError::StorageFull);
81        }
82
83        data.push((
84            id.clone(),
85            tenant_id,
86            VectorEmbedding {
87                id,
88                vector,
89                metadata,
90            },
91        ));
92
93        Ok(())
94    }
95
96    async fn search(
97        &self,
98        tenant_id: &str,
99        query: &[f32],
100        k: usize,
101        filters: Option<HashMap<String, String>>,
102    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
103        if query.len() != self.dimension {
104            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
105        }
106
107        let data = self.embeddings.read().unwrap();
108        let mut scores: Vec<(f32, VectorEmbedding)> = data
109            .iter()
110            .filter(|(_, tid, emb)| {
111                if tid != tenant_id {
112                    return false;
113                }
114
115                // Apply metadata filters
116                if let Some(ref f) = filters {
117                    for (key, val) in f {
118                        if emb.metadata.get(key) != Some(val) {
119                            return false;
120                        }
121                    }
122                }
123
124                true
125            })
126            .map(|(_, _, emb)| {
127                let score = cosine_similarity(query, &emb.vector);
128                (score, emb.clone())
129            })
130            .collect();
131
132        scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
133        scores.truncate(k);
134
135        Ok(scores)
136    }
137}
138
/// SQLite-backed persistent vector store
#[derive(Debug, Clone)]
pub struct SqliteVectorStore {
    /// Required length of every stored and query vector.
    dimension: usize,
    /// Shared sqlx connection pool (cheap to clone).
    pool: SqlitePool,
}
145
146impl SqliteVectorStore {
147    pub fn new(dimension: usize, pool: SqlitePool) -> Self {
148        Self { dimension, pool }
149    }
150}
151
152#[async_trait]
153impl VectorStoreBackend for SqliteVectorStore {
154    async fn add(
155        &self,
156        id: String,
157        tenant_id: String,
158        vector: Vec<f32>,
159        metadata: HashMap<String, String>,
160    ) -> Result<(), VectorError> {
161        if vector.len() != self.dimension {
162            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
163        }
164
165        // Convert f32 vector to bytes (Little Endian)
166        let mut vector_bytes = Vec::with_capacity(vector.len() * 4);
167        for &val in &vector {
168            vector_bytes.extend_from_slice(&val.to_le_bytes());
169        }
170
171        let metadata_json = serde_json::to_string(&metadata)
172            .map_err(|e| VectorError::SerializationError(e.to_string()))?;
173
174        sqlx::query(
175            "INSERT OR REPLACE INTO vector_embeddings (id, tenant_id, vector, metadata, created_at) VALUES (?, ?, ?, ?, ?)"
176        )
177        .bind(id)
178        .bind(tenant_id)
179        .bind(vector_bytes)
180        .bind(metadata_json)
181        .bind(chrono::Utc::now().timestamp())
182        .execute(&self.pool)
183        .await
184        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
185
186        Ok(())
187    }
188
189    async fn search(
190        &self,
191        tenant_id: &str,
192        query: &[f32],
193        k: usize,
194        filters: Option<HashMap<String, String>>,
195    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
196        if query.len() != self.dimension {
197            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
198        }
199
200        let mut sql =
201            "SELECT id, vector, metadata FROM vector_embeddings WHERE tenant_id = ?".to_string();
202        if let Some(ref f) = filters {
203            for key in f.keys() {
204                sql.push_str(&format!(" AND json_extract(metadata, '$.{}') = ?", key));
205            }
206        }
207
208        let mut q = sqlx::query(&sql).bind(tenant_id);
209
210        if let Some(ref f) = filters {
211            for val in f.values() {
212                q = q.bind(val);
213            }
214        }
215
216        let rows = q
217            .fetch_all(&self.pool)
218            .await
219            .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
220
221        let mut scores = Vec::new();
222
223        for row in rows {
224            let id: String = row.get("id");
225            let vector_bytes: Vec<u8> = row.get("vector");
226            let metadata_str: String = row.get("metadata");
227
228            // Convert bytes back to f32 vector
229            if vector_bytes.len() != self.dimension * 4 {
230                continue; // Skip corrupted entry
231            }
232
233            let mut vector = Vec::with_capacity(self.dimension);
234            for chunk in vector_bytes.chunks_exact(4) {
235                let arr: [u8; 4] = chunk.try_into().unwrap();
236                vector.push(f32::from_le_bytes(arr));
237            }
238
239            let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
240                .map_err(|e| VectorError::SerializationError(e.to_string()))?;
241
242            let score = cosine_similarity(query, &vector);
243            scores.push((
244                score,
245                VectorEmbedding {
246                    id,
247                    vector,
248                    metadata,
249                },
250            ));
251        }
252
253        scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
254        scores.truncate(k);
255
256        Ok(scores)
257    }
258}
259
/// Cosine similarity of `a` and `b`, in `[-1, 1]`; returns `0.0` when either
/// vector has zero magnitude (rather than dividing by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| v.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt();
    let dot = a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| acc + x * y);
    let norm_a = magnitude(a);
    let norm_b = magnitude(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
271
/// PostgreSQL-backed vector store using pgvector native extension
/// Uses HNSW index for fast approximate nearest-neighbor search
/// Requires the `vector` extension: `CREATE EXTENSION IF NOT EXISTS vector;`
#[cfg(feature = "postgres")]
#[derive(Debug, Clone)]
pub struct PgVectorStore {
    /// Required length of every stored and query vector.
    dimension: usize,
    /// Shared sqlx Postgres connection pool (cheap to clone).
    pool: sqlx::PgPool,
}
281
#[cfg(feature = "postgres")]
impl PgVectorStore {
    /// Wraps an existing Postgres pool; stored vectors must have
    /// exactly `dimension` components.
    pub fn new(dimension: usize, pool: sqlx::PgPool) -> Self {
        PgVectorStore { dimension, pool }
    }
}
288
#[cfg(feature = "postgres")]
#[async_trait]
impl VectorStoreBackend for PgVectorStore {
    /// Upserts the embedding under `(id, tenant_id)` using pgvector's native
    /// `vector` column type and `ON CONFLICT … DO UPDATE`.
    async fn add(
        &self,
        id: String,
        tenant_id: String,
        vector: Vec<f32>,
        metadata: HashMap<String, String>,
    ) -> Result<(), VectorError> {
        if vector.len() != self.dimension {
            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
        }

        let metadata_json = serde_json::to_string(&metadata)
            .map_err(|e| VectorError::SerializationError(e.to_string()))?;

        // Use pgvector::Vector type for native Postgres vector storage
        let pg_vector = pgvector::Vector::from(vector);

        // NOTE(review): relies on a UNIQUE constraint over (id, tenant_id)
        // existing in the schema — confirm against the migration.
        sqlx::query(
            "INSERT INTO vector_embeddings (id, tenant_id, vector, metadata) VALUES ($1, $2, $3::vector, $4)
             ON CONFLICT (id, tenant_id) DO UPDATE SET vector = EXCLUDED.vector, metadata = EXCLUDED.metadata"
        )
        .bind(&id)
        .bind(&tenant_id)
        .bind(pg_vector)
        .bind(metadata_json)
        .execute(&self.pool)
        .await
        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;

        Ok(())
    }

    /// Server-side nearest-neighbor search: orders by the pgvector cosine
    /// distance operator `<=>` (so any HNSW/IVF index can serve the scan)
    /// and reports `1 - distance` as the similarity score.
    async fn search(
        &self,
        tenant_id: &str,
        query: &[f32],
        k: usize,
        filters: Option<HashMap<String, String>>,
    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
        if query.len() != self.dimension {
            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
        }

        let pg_query = pgvector::Vector::from(query.to_vec());
        // NOTE(review): on the (practically unreachable) serialization
        // failure this falls back to "{}", which matches everything —
        // i.e. the filter is silently dropped rather than surfaced as an
        // error. Consider propagating SerializationError instead.
        let filters_json = filters
            .as_ref()
            .map(|f| serde_json::to_string(f).unwrap_or_else(|_| "{}".to_string()));

        // Two query variants because the bind list differs; `metadata @> $3`
        // is jsonb containment, matching the "all pairs present" semantics
        // of the other backends.
        let rows = if let Some(fj) = filters_json {
            sqlx::query(
                "SELECT id, vector, metadata, 1 - (vector <=> $1::vector) AS score
                 FROM vector_embeddings
                 WHERE tenant_id = $2 AND metadata @> $3::jsonb
                 ORDER BY vector <=> $1::vector
                 LIMIT $4",
            )
            .bind(pg_query)
            .bind(tenant_id)
            .bind(fj)
            .bind(k as i64)
            .fetch_all(&self.pool)
            .await
        } else {
            sqlx::query(
                "SELECT id, vector, metadata, 1 - (vector <=> $1::vector) AS score
                 FROM vector_embeddings
                 WHERE tenant_id = $2
                 ORDER BY vector <=> $1::vector
                 LIMIT $3",
            )
            .bind(pg_query)
            .bind(tenant_id)
            .bind(k as i64)
            .fetch_all(&self.pool)
            .await
        }
        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;

        let mut results = Vec::new();
        for row in rows {
            use sqlx::Row;
            let id: String = row.get("id");
            // NOTE(review): decoding metadata as String assumes a text-typed
            // column; if the column is jsonb, sqlx may require a different
            // decode — confirm against the schema.
            let metadata_str: String = row.get("metadata");
            // Missing/NULL score degrades to 0.0 rather than failing the row.
            let score: f32 = row.try_get("score").unwrap_or(0.0);
            let pg_vec: pgvector::Vector = row.get("vector");
            let vector: Vec<f32> = pg_vec.to_vec();

            let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
                .map_err(|e| VectorError::SerializationError(e.to_string()))?;

            results.push((
                score,
                VectorEmbedding {
                    id,
                    vector,
                    metadata,
                },
            ));
        }

        // Rows arrive already ordered by distance, so no re-sort is needed.
        Ok(results)
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata filtering on the in-memory backend: single-key filter,
    /// non-matching filter, and multi-key (AND) filter.
    #[tokio::test]
    async fn test_memory_vector_store_filtering() {
        let store = MemoryVectorStore::new(3);
        let tenant = "t1";

        let mut m1 = HashMap::new();
        m1.insert("type".to_string(), "a".to_string());
        m1.insert("cat".to_string(), "1".to_string());

        let mut m2 = HashMap::new();
        m2.insert("type".to_string(), "b".to_string());

        store
            .add("1".into(), tenant.into(), vec![1.0, 0.0, 0.0], m1)
            .await
            .unwrap();
        store
            .add("2".into(), tenant.into(), vec![0.0, 1.0, 0.0], m2)
            .await
            .unwrap();

        // 1. Filter by type=a
        let mut filter = HashMap::new();
        filter.insert("type".to_string(), "a".to_string());
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");

        // 2. Filter by non-existent type
        let mut filter = HashMap::new();
        filter.insert("type".to_string(), "c".to_string());
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 0);

        // 3. Multi-filter: both pairs must match (AND semantics)
        let mut filter = HashMap::new();
        filter.insert("type".to_string(), "a".to_string());
        filter.insert("cat".to_string(), "1".to_string());
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");
    }

    /// Same filtering scenarios against the SQLite backend, using an
    /// in-memory database with the schema created inline.
    #[tokio::test]
    async fn test_sqlite_vector_store_filtering() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();

        // Setup table
        sqlx::query("CREATE TABLE vector_embeddings (id TEXT PRIMARY KEY, tenant_id TEXT NOT NULL, vector BLOB NOT NULL, metadata JSON NOT NULL, created_at INTEGER NOT NULL)")
            .execute(&pool).await.unwrap();

        let store = SqliteVectorStore::new(3, pool);
        let tenant = "t1";

        let mut m1 = HashMap::new();
        m1.insert("type".to_string(), "a".to_string());

        let mut m2 = HashMap::new();
        m2.insert("type".to_string(), "b".to_string());

        store
            .add("1".into(), tenant.into(), vec![1.0, 0.0, 0.0], m1)
            .await
            .unwrap();
        store
            .add("2".into(), tenant.into(), vec![0.0, 1.0, 0.0], m2)
            .await
            .unwrap();

        // 1. Filter by type=a
        let mut filter = HashMap::new();
        filter.insert("type".to_string(), "a".to_string());
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");

        // 2. Filter by non-existent type
        let mut filter = HashMap::new();
        filter.insert("type".to_string(), "c".to_string());
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(filter))
            .await
            .unwrap();
        assert_eq!(results.len(), 0);
    }
}