1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use sqlx::{Row, SqlitePool};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use thiserror::Error;
7
/// Errors produced by the vector-store backends in this module.
#[derive(Error, Debug)]
pub enum VectorError {
    /// A stored or query vector's length did not match the store's configured dimension.
    #[error("Dimension mismatch: expected {0}, got {1}")]
    DimensionMismatch(usize, usize),
    /// Metadata could not be serialized to / deserialized from JSON.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// The underlying database reported an error (message preserved as text).
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// The in-memory store reached its fixed entry capacity.
    #[error("Storage full: capacity exceeded")]
    StorageFull,
}
19
/// A stored vector together with its identifier and string metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEmbedding {
    /// Caller-supplied identifier for this embedding.
    pub id: String,
    /// The embedding values; length must equal the owning store's dimension.
    pub vector: Vec<f32>,
    /// Arbitrary key/value metadata; usable as exact-match search filters.
    pub metadata: HashMap<String, String>,
}
26
/// Common interface for tenant-scoped vector storage backends.
///
/// Implementations in this file: in-memory, SQLite, and (feature-gated)
/// Postgres/pgvector. All data is partitioned by `tenant_id`.
#[async_trait]
pub trait VectorStoreBackend: Send + Sync + std::fmt::Debug {
    /// Stores `vector` with `metadata` under `id` for the given tenant.
    ///
    /// Returns `DimensionMismatch` when `vector.len()` differs from the
    /// store's configured dimension.
    async fn add(
        &self,
        id: String,
        tenant_id: String,
        vector: Vec<f32>,
        metadata: HashMap<String, String>,
    ) -> Result<(), VectorError>;

    /// Returns up to `k` `(score, embedding)` pairs for the tenant, best
    /// first (cosine similarity). When `filters` is `Some`, every key/value
    /// pair must match the embedding's metadata exactly.
    async fn search(
        &self,
        tenant_id: &str,
        query: &[f32],
        k: usize,
        filters: Option<HashMap<String, String>>,
    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError>;
}
46
47#[derive(Debug, Clone)]
49pub struct MemoryVectorStore {
50 dimension: usize,
51 embeddings: Arc<RwLock<Vec<(String, String, VectorEmbedding)>>>, }
53
54impl MemoryVectorStore {
55 pub fn new(dimension: usize) -> Self {
56 Self {
57 dimension,
58 embeddings: Arc::new(RwLock::new(Vec::new())),
59 }
60 }
61}
62
63#[async_trait]
64impl VectorStoreBackend for MemoryVectorStore {
65 async fn add(
66 &self,
67 id: String,
68 tenant_id: String,
69 vector: Vec<f32>,
70 metadata: HashMap<String, String>,
71 ) -> Result<(), VectorError> {
72 if vector.len() != self.dimension {
73 return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
74 }
75
76 let mut data = self.embeddings.write().unwrap();
77
78 if data.len() >= 100_000 {
80 return Err(VectorError::StorageFull);
81 }
82
83 data.push((
84 id.clone(),
85 tenant_id,
86 VectorEmbedding {
87 id,
88 vector,
89 metadata,
90 },
91 ));
92
93 Ok(())
94 }
95
96 async fn search(
97 &self,
98 tenant_id: &str,
99 query: &[f32],
100 k: usize,
101 filters: Option<HashMap<String, String>>,
102 ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
103 if query.len() != self.dimension {
104 return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
105 }
106
107 let data = self.embeddings.read().unwrap();
108 let mut scores: Vec<(f32, VectorEmbedding)> = data
109 .iter()
110 .filter(|(_, tid, emb)| {
111 if tid != tenant_id {
112 return false;
113 }
114
115 if let Some(ref f) = filters {
117 for (key, val) in f {
118 if emb.metadata.get(key) != Some(val) {
119 return false;
120 }
121 }
122 }
123
124 true
125 })
126 .map(|(_, _, emb)| {
127 let score = cosine_similarity(query, &emb.vector);
128 (score, emb.clone())
129 })
130 .collect();
131
132 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
133 scores.truncate(k);
134
135 Ok(scores)
136 }
137}
138
/// A vector store persisted to SQLite: vectors are stored as little-endian
/// f32 BLOBs, metadata as JSON text.
#[derive(Debug, Clone)]
pub struct SqliteVectorStore {
    /// Required length of every stored and query vector.
    dimension: usize,
    /// Connection pool for the backing SQLite database.
    pool: SqlitePool,
}
145
146impl SqliteVectorStore {
147 pub fn new(dimension: usize, pool: SqlitePool) -> Self {
148 Self { dimension, pool }
149 }
150}
151
152#[async_trait]
153impl VectorStoreBackend for SqliteVectorStore {
154 async fn add(
155 &self,
156 id: String,
157 tenant_id: String,
158 vector: Vec<f32>,
159 metadata: HashMap<String, String>,
160 ) -> Result<(), VectorError> {
161 if vector.len() != self.dimension {
162 return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
163 }
164
165 let mut vector_bytes = Vec::with_capacity(vector.len() * 4);
167 for &val in &vector {
168 vector_bytes.extend_from_slice(&val.to_le_bytes());
169 }
170
171 let metadata_json = serde_json::to_string(&metadata)
172 .map_err(|e| VectorError::SerializationError(e.to_string()))?;
173
174 sqlx::query(
175 "INSERT OR REPLACE INTO vector_embeddings (id, tenant_id, vector, metadata, created_at) VALUES (?, ?, ?, ?, ?)"
176 )
177 .bind(id)
178 .bind(tenant_id)
179 .bind(vector_bytes)
180 .bind(metadata_json)
181 .bind(chrono::Utc::now().timestamp())
182 .execute(&self.pool)
183 .await
184 .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
185
186 Ok(())
187 }
188
189 async fn search(
190 &self,
191 tenant_id: &str,
192 query: &[f32],
193 k: usize,
194 filters: Option<HashMap<String, String>>,
195 ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
196 if query.len() != self.dimension {
197 return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
198 }
199
200 let mut sql =
201 "SELECT id, vector, metadata FROM vector_embeddings WHERE tenant_id = ?".to_string();
202 if let Some(ref f) = filters {
203 for key in f.keys() {
204 sql.push_str(&format!(" AND json_extract(metadata, '$.{}') = ?", key));
205 }
206 }
207
208 let mut q = sqlx::query(&sql).bind(tenant_id);
209
210 if let Some(ref f) = filters {
211 for val in f.values() {
212 q = q.bind(val);
213 }
214 }
215
216 let rows = q
217 .fetch_all(&self.pool)
218 .await
219 .map_err(|e| VectorError::DatabaseError(e.to_string()))?;
220
221 let mut scores = Vec::new();
222
223 for row in rows {
224 let id: String = row.get("id");
225 let vector_bytes: Vec<u8> = row.get("vector");
226 let metadata_str: String = row.get("metadata");
227
228 if vector_bytes.len() != self.dimension * 4 {
230 continue; }
232
233 let mut vector = Vec::with_capacity(self.dimension);
234 for chunk in vector_bytes.chunks_exact(4) {
235 let arr: [u8; 4] = chunk.try_into().unwrap();
236 vector.push(f32::from_le_bytes(arr));
237 }
238
239 let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
240 .map_err(|e| VectorError::SerializationError(e.to_string()))?;
241
242 let score = cosine_similarity(query, &vector);
243 scores.push((
244 score,
245 VectorEmbedding {
246 id,
247 vector,
248 metadata,
249 },
250 ));
251 }
252
253 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
254 scores.truncate(k);
255
256 Ok(scores)
257 }
258}
259
/// Cosine similarity of two equal-length vectors, in `[-1, 1]`.
///
/// Returns 0.0 when either vector has zero magnitude, so degenerate inputs
/// rank last instead of producing NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate dot product and both squared norms in a single pass.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
271
/// A vector store backed by Postgres with the pgvector extension
/// (feature-gated behind `postgres`).
#[cfg(feature = "postgres")]
#[derive(Debug, Clone)]
pub struct PgVectorStore {
    /// Required length of every stored and query vector.
    dimension: usize,
    /// Connection pool for the backing Postgres database.
    pool: sqlx::PgPool,
}
281
#[cfg(feature = "postgres")]
impl PgVectorStore {
    /// Wraps an existing Postgres pool; assumes a `vector_embeddings` table
    /// with a pgvector `vector` column already exists.
    pub fn new(dimension: usize, pool: sqlx::PgPool) -> Self {
        PgVectorStore { dimension, pool }
    }
}
288
#[cfg(feature = "postgres")]
#[async_trait]
impl VectorStoreBackend for PgVectorStore {
    /// Upserts an embedding via pgvector; conflicts on (id, tenant_id)
    /// update the stored vector and metadata in place.
    async fn add(
        &self,
        id: String,
        tenant_id: String,
        vector: Vec<f32>,
        metadata: HashMap<String, String>,
    ) -> Result<(), VectorError> {
        if vector.len() != self.dimension {
            return Err(VectorError::DimensionMismatch(self.dimension, vector.len()));
        }

        let metadata_json = serde_json::to_string(&metadata)
            .map_err(|e| VectorError::SerializationError(e.to_string()))?;

        let pg_vector = pgvector::Vector::from(vector);

        sqlx::query(
            "INSERT INTO vector_embeddings (id, tenant_id, vector, metadata) VALUES ($1, $2, $3::vector, $4)
             ON CONFLICT (id, tenant_id) DO UPDATE SET vector = EXCLUDED.vector, metadata = EXCLUDED.metadata"
        )
        .bind(&id)
        .bind(&tenant_id)
        .bind(pg_vector)
        .bind(metadata_json)
        .execute(&self.pool)
        .await
        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;

        Ok(())
    }

    /// Ranks by cosine distance in the database (`<=>`) and returns the top
    /// `k` rows with `score = 1 - distance`. Metadata filters use jsonb
    /// containment (`@>`), so all filter pairs must be present.
    async fn search(
        &self,
        tenant_id: &str,
        query: &[f32],
        k: usize,
        filters: Option<HashMap<String, String>>,
    ) -> Result<Vec<(f32, VectorEmbedding)>, VectorError> {
        if query.len() != self.dimension {
            return Err(VectorError::DimensionMismatch(self.dimension, query.len()));
        }

        let pg_query = pgvector::Vector::from(query.to_vec());
        let filters_json = filters
            .as_ref()
            .map(|f| serde_json::to_string(f).unwrap_or_else(|_| "{}".to_string()));

        let rows = if let Some(fj) = filters_json {
            sqlx::query(
                "SELECT id, vector, metadata, 1 - (vector <=> $1::vector) AS score
                 FROM vector_embeddings
                 WHERE tenant_id = $2 AND metadata @> $3::jsonb
                 ORDER BY vector <=> $1::vector
                 LIMIT $4",
            )
            .bind(pg_query)
            .bind(tenant_id)
            .bind(fj)
            .bind(k as i64)
            .fetch_all(&self.pool)
            .await
        } else {
            sqlx::query(
                "SELECT id, vector, metadata, 1 - (vector <=> $1::vector) AS score
                 FROM vector_embeddings
                 WHERE tenant_id = $2
                 ORDER BY vector <=> $1::vector
                 LIMIT $3",
            )
            .bind(pg_query)
            .bind(tenant_id)
            .bind(k as i64)
            .fetch_all(&self.pool)
            .await
        }
        .map_err(|e| VectorError::DatabaseError(e.to_string()))?;

        let mut results = Vec::new();
        for row in rows {
            let id: String = row.get("id");
            let metadata_str: String = row.get("metadata");
            // `1 - (vector <=> $1)` is a Postgres `double precision`
            // (float8). sqlx does not decode float8 as f32, so the previous
            // `try_get::<f32>` always failed and collapsed every score to the
            // 0.0 fallback — decode as f64 and narrow instead.
            let score: f32 = row.try_get::<f64, _>("score").map_or(0.0, |s| s as f32);
            let pg_vec: pgvector::Vector = row.get("vector");
            let vector: Vec<f32> = pg_vec.to_vec();

            let metadata: HashMap<String, String> = serde_json::from_str(&metadata_str)
                .map_err(|e| VectorError::SerializationError(e.to_string()))?;

            results.push((
                score,
                VectorEmbedding {
                    id,
                    vector,
                    metadata,
                },
            ));
        }

        Ok(results)
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata map from literal key/value pairs.
    fn meta(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[tokio::test]
    async fn test_memory_vector_store_filtering() {
        let store = MemoryVectorStore::new(3);
        let tenant = "t1";

        store
            .add(
                "1".into(),
                tenant.into(),
                vec![1.0, 0.0, 0.0],
                meta(&[("type", "a"), ("cat", "1")]),
            )
            .await
            .unwrap();
        store
            .add(
                "2".into(),
                tenant.into(),
                vec![0.0, 1.0, 0.0],
                meta(&[("type", "b")]),
            )
            .await
            .unwrap();

        // Single-key filter matches exactly one embedding.
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(meta(&[("type", "a")])))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");

        // A value present on no embedding yields no results.
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(meta(&[("type", "c")])))
            .await
            .unwrap();
        assert_eq!(results.len(), 0);

        // Multi-key filters require every pair to match.
        let results = store
            .search(
                tenant,
                &[1.0, 0.0, 0.0],
                10,
                Some(meta(&[("type", "a"), ("cat", "1")])),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");
    }

    #[tokio::test]
    async fn test_sqlite_vector_store_filtering() {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();

        sqlx::query("CREATE TABLE vector_embeddings (id TEXT PRIMARY KEY, tenant_id TEXT NOT NULL, vector BLOB NOT NULL, metadata JSON NOT NULL, created_at INTEGER NOT NULL)")
            .execute(&pool).await.unwrap();

        let store = SqliteVectorStore::new(3, pool);
        let tenant = "t1";

        store
            .add(
                "1".into(),
                tenant.into(),
                vec![1.0, 0.0, 0.0],
                meta(&[("type", "a")]),
            )
            .await
            .unwrap();
        store
            .add(
                "2".into(),
                tenant.into(),
                vec![0.0, 1.0, 0.0],
                meta(&[("type", "b")]),
            )
            .await
            .unwrap();

        // Matching filter returns only the matching row.
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(meta(&[("type", "a")])))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1.id, "1");

        // Non-matching filter value returns nothing.
        let results = store
            .search(tenant, &[1.0, 0.0, 0.0], 10, Some(meta(&[("type", "c")])))
            .await
            .unwrap();
        assert_eq!(results.len(), 0);
    }
}
497}