Skip to main content

xz_embed/store/
sqlite_vec.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::fmt::Debug;
4
5use crate::error::StoreError;
6use crate::traits::{StoreLifecycle, VectorStore};
7use crate::types::{MetadataFilter, SearchResult, StoreStats, VectorEntry};
8
/// sqlite-vec vector store implementation.
///
/// Features:
/// - Zero external dependencies (via sqlx + sqlite)
/// - Cosine distance search
/// - Metadata filtering implemented through SQL `WHERE` clauses
/// - WAL mode to support concurrent reads
#[derive(Debug)]
pub struct SqliteVecStore {
    /// Connection pool that every statement issued by this store runs on.
    pool: sqlx::SqlitePool,
    /// Required length of every inserted vector and every query vector.
    dimensions: usize,
    /// Table holding the embeddings; interpolated directly into SQL text.
    table_name: String,
    /// Optional capacity limit.
    /// NOTE(review): stored by the builder, but nothing in the visible code enforces it.
    max_capacity: Option<u64>,
}
23
24impl SqliteVecStore {
25    /// 创建新的 sqlite-vec 存储
26    pub async fn new(
27        path: &str,
28        dimensions: usize,
29        max_pool_size: Option<usize>,
30    ) -> Result<Self, StoreError> {
31        let pool_size = max_pool_size.unwrap_or(5);
32        let conn_str = if path == ":memory:" {
33            "sqlite::memory:".to_string()
34        } else {
35            format!("sqlite:{path}")
36        };
37
38        let pool = sqlx::sqlite::SqlitePoolOptions::new()
39            .max_connections(pool_size as u32)
40            .connect(&conn_str)
41            .await
42            .map_err(|e| StoreError::Database(e.to_string()))?;
43
44        // 启用 WAL 模式
45        sqlx::query("PRAGMA journal_mode=WAL")
46            .execute(&pool)
47            .await
48            .map_err(|e| StoreError::Database(e.to_string()))?;
49
50        Ok(Self {
51            pool,
52            dimensions,
53            table_name: "embeddings".into(),
54            max_capacity: None,
55        })
56    }
57
58    /// 设置表名
59    pub fn with_table_name(mut self, name: &str) -> Self {
60        self.table_name = name.to_string();
61        self
62    }
63
64    /// 设置最大存储容量
65    pub fn with_max_capacity(mut self, capacity: Option<u64>) -> Self {
66        self.max_capacity = capacity;
67        self
68    }
69
70    /// 清理过期数据
71    pub async fn purge_expired(&self) -> Result<usize, StoreError> {
72        let now_ms = std::time::SystemTime::now()
73            .duration_since(std::time::UNIX_EPOCH)
74            .unwrap_or_default()
75            .as_millis() as u64;
76
77        let result = sqlx::query(&format!(
78            "DELETE FROM {} WHERE expires_at IS NOT NULL AND expires_at < ?",
79            self.table_name
80        ))
81        .bind(now_ms as i64)
82        .execute(&self.pool)
83        .await
84        .map_err(|e| StoreError::Database(e.to_string()))?;
85
86        Ok(result.rows_affected() as usize)
87    }
88
89    fn build_filter_clause(filter: &MetadataFilter) -> (String, Vec<String>) {
90        match filter {
91            MetadataFilter::Eq { key, value } => (
92                format!("json_extract(metadata_json, '$.{key}') = ?"),
93                vec![value.clone()],
94            ),
95            MetadataFilter::Ne { key, value } => (
96                format!("json_extract(metadata_json, '$.{key}') != ?"),
97                vec![value.clone()],
98            ),
99            MetadataFilter::In { key, values } => {
100                let placeholders: Vec<String> = values.iter().map(|_| "?".to_string()).collect();
101                (
102                    format!(
103                        "json_extract(metadata_json, '$.{key}') IN ({})",
104                        placeholders.join(", ")
105                    ),
106                    values.clone(),
107                )
108            }
109            MetadataFilter::NotIn { key, values } => {
110                let placeholders: Vec<String> = values.iter().map(|_| "?".to_string()).collect();
111                (
112                    format!(
113                        "json_extract(metadata_json, '$.{key}') NOT IN ({})",
114                        placeholders.join(", ")
115                    ),
116                    values.clone(),
117                )
118            }
119            MetadataFilter::Exists { key } => (
120                format!("json_extract(metadata_json, '$.{key}') IS NOT NULL"),
121                vec![],
122            ),
123            MetadataFilter::Contains { key, value } => (
124                format!("json_extract(metadata_json, '$.{key}') LIKE ?"),
125                vec![format!("%{value}%")],
126            ),
127            MetadataFilter::Range { key, min, max } => {
128                let mut clauses = Vec::new();
129                let mut params = Vec::new();
130                if let Some(min_val) = min {
131                    clauses.push(format!(
132                        "CAST(json_extract(metadata_json, '$.{key}') AS REAL) >= ?"
133                    ));
134                    params.push(min_val.to_string());
135                }
136                if let Some(max_val) = max {
137                    clauses.push(format!(
138                        "CAST(json_extract(metadata_json, '$.{key}') AS REAL) <= ?"
139                    ));
140                    params.push(max_val.to_string());
141                }
142                (clauses.join(" AND "), params)
143            }
144            MetadataFilter::And(filters) => {
145                let mut clauses = Vec::new();
146                let mut all_params = Vec::new();
147                for f in filters {
148                    let (clause, mut params) = Self::build_filter_clause(f);
149                    if !clause.is_empty() {
150                        clauses.push(format!("({clause})"));
151                        all_params.append(&mut params);
152                    }
153                }
154                (clauses.join(" AND "), all_params)
155            }
156            MetadataFilter::Or(filters) => {
157                let mut clauses = Vec::new();
158                let mut all_params = Vec::new();
159                for f in filters {
160                    let (clause, mut params) = Self::build_filter_clause(f);
161                    if !clause.is_empty() {
162                        clauses.push(format!("({clause})"));
163                        all_params.append(&mut params);
164                    }
165                }
166                (clauses.join(" OR "), all_params)
167            }
168            MetadataFilter::Not(filter) => {
169                let (inner, params) = Self::build_filter_clause(filter);
170                (format!("NOT ({inner})"), params)
171            }
172        }
173    }
174
175    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
176        if a.len() != b.len() {
177            return 0.0;
178        }
179        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
180        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
181        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
182        if norm_a == 0.0 || norm_b == 0.0 {
183            return 0.0;
184        }
185        dot / (norm_a * norm_b)
186    }
187
188    fn vector_to_blob(v: &[f32]) -> Vec<u8> {
189        v.iter().flat_map(|f| f.to_le_bytes()).collect()
190    }
191
192    fn blob_to_vector(b: &[u8]) -> Vec<f32> {
193        b.chunks_exact(4)
194            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
195            .collect()
196    }
197}
198
199#[async_trait]
200impl VectorStore for SqliteVecStore {
201    async fn insert(&self, entry: VectorEntry) -> Result<(), StoreError> {
202        self.insert_batch(vec![entry]).await
203    }
204
205    async fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<(), StoreError> {
206        if entries.is_empty() {
207            return Ok(());
208        }
209
210        // 检查维度
211        for entry in &entries {
212            if entry.vector.len() != self.dimensions {
213                return Err(StoreError::DimensionMismatch {
214                    expected: self.dimensions,
215                    actual: entry.vector.len(),
216                });
217            }
218        }
219
220        for entry in entries {
221            let vector_blob = Self::vector_to_blob(&entry.vector);
222            let metadata_json = serde_json::to_string(&entry.metadata)
223                .map_err(|e| StoreError::Serialization(e.to_string()))?;
224
225            sqlx::query(&format!(
226                "INSERT OR REPLACE INTO {} (id, content, metadata_json, channel, created_at, expires_at, embedding) VALUES (?, ?, ?, ?, ?, ?, ?)",
227                self.table_name
228            ))
229            .bind(&entry.id)
230            .bind(&entry.content)
231            .bind(&metadata_json)
232            .bind(&entry.channel)
233            .bind(entry.created_at as i64)
234            .bind(entry.expires_at.map(|t| t as i64))
235            .bind(&vector_blob)
236            .execute(&self.pool)
237            .await
238            .map_err(|e| StoreError::Database(e.to_string()))?;
239        }
240
241        Ok(())
242    }
243
244    async fn search(&self, query: &[f32], limit: usize) -> Result<Vec<SearchResult>, StoreError> {
245        if query.len() != self.dimensions {
246            return Err(StoreError::DimensionMismatch {
247                expected: self.dimensions,
248                actual: query.len(),
249            });
250        }
251
252        let rows = sqlx::query_as::<_, EmbeddingRow>(&format!(
253            "SELECT id, content, metadata_json, channel, embedding FROM {}",
254            self.table_name
255        ))
256        .fetch_all(&self.pool)
257        .await
258        .map_err(|e| StoreError::Database(e.to_string()))?;
259
260        let mut scored: Vec<(SearchResult, f32)> = rows
261            .iter()
262            .map(|row| {
263                let vector = Self::blob_to_vector(&row.embedding);
264                let similarity = Self::cosine_similarity(query, &vector);
265                let metadata: HashMap<String, String> =
266                    serde_json::from_str(&row.metadata_json).unwrap_or_default();
267
268                (
269                    SearchResult {
270                        id: row.id.clone(),
271                        score: similarity,
272                        metadata,
273                        content: row.content.clone(),
274                        channel: row.channel.clone(),
275                    },
276                    similarity,
277                )
278            })
279            .collect();
280
281        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282        scored.truncate(limit);
283
284        Ok(scored.into_iter().map(|(r, _)| r).collect())
285    }
286
287    async fn search_with_filter(
288        &self,
289        query: &[f32],
290        filter: &MetadataFilter,
291        limit: usize,
292    ) -> Result<Vec<SearchResult>, StoreError> {
293        if query.len() != self.dimensions {
294            return Err(StoreError::DimensionMismatch {
295                expected: self.dimensions,
296                actual: query.len(),
297            });
298        }
299
300        let (filter_clause, params) = Self::build_filter_clause(filter);
301        let sql = format!(
302            "SELECT id, content, metadata_json, channel, embedding FROM {} WHERE {}",
303            self.table_name, filter_clause
304        );
305
306        let mut query_builder = sqlx::query_as::<_, EmbeddingRow>(&sql);
307        for param in &params {
308            query_builder = query_builder.bind(param);
309        }
310
311        let rows = query_builder
312            .fetch_all(&self.pool)
313            .await
314            .map_err(|e| StoreError::Database(e.to_string()))?;
315
316        let mut scored: Vec<(SearchResult, f32)> = rows
317            .iter()
318            .map(|row| {
319                let vector = Self::blob_to_vector(&row.embedding);
320                let similarity = Self::cosine_similarity(query, &vector);
321                let metadata: HashMap<String, String> =
322                    serde_json::from_str(&row.metadata_json).unwrap_or_default();
323
324                (
325                    SearchResult {
326                        id: row.id.clone(),
327                        score: similarity,
328                        metadata,
329                        content: row.content.clone(),
330                        channel: row.channel.clone(),
331                    },
332                    similarity,
333                )
334            })
335            .collect();
336
337        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
338        scored.truncate(limit);
339
340        Ok(scored.into_iter().map(|(r, _)| r).collect())
341    }
342
343    async fn delete(&self, ids: &[String]) -> Result<usize, StoreError> {
344        if ids.is_empty() {
345            return Ok(0);
346        }
347        let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
348        let sql = format!(
349            "DELETE FROM {} WHERE id IN ({})",
350            self.table_name,
351            placeholders.join(", ")
352        );
353
354        let mut query = sqlx::query(&sql);
355        for id in ids {
356            query = query.bind(id);
357        }
358
359        let result = query
360            .execute(&self.pool)
361            .await
362            .map_err(|e| StoreError::Database(e.to_string()))?;
363
364        Ok(result.rows_affected() as usize)
365    }
366
367    async fn delete_by_filter(&self, filter: &MetadataFilter) -> Result<usize, StoreError> {
368        let (filter_clause, params) = Self::build_filter_clause(filter);
369        let sql = format!("DELETE FROM {} WHERE {}", self.table_name, filter_clause);
370
371        let mut query = sqlx::query(&sql);
372        for param in &params {
373            query = query.bind(param);
374        }
375
376        let result = query
377            .execute(&self.pool)
378            .await
379            .map_err(|e| StoreError::Database(e.to_string()))?;
380
381        Ok(result.rows_affected() as usize)
382    }
383
384    async fn clear(&self) -> Result<(), StoreError> {
385        sqlx::query(&format!("DELETE FROM {}", self.table_name))
386            .execute(&self.pool)
387            .await
388            .map_err(|e| StoreError::Database(e.to_string()))?;
389        Ok(())
390    }
391
392    async fn count(&self) -> Result<usize, StoreError> {
393        let (count,): (i64,) = sqlx::query_as(&format!(
394            "SELECT COUNT(*) FROM {}",
395            self.table_name
396        ))
397        .fetch_one(&self.pool)
398        .await
399        .map_err(|e| StoreError::Database(e.to_string()))?;
400
401        Ok(count as usize)
402    }
403
404    async fn rebuild_index(&self) -> Result<(), StoreError> {
405        // sqlite-vec 不依赖传统索引,此操作仅做 WAL checkpoint
406        sqlx::query("PRAGMA wal_checkpoint(FULL)")
407            .execute(&self.pool)
408            .await
409            .map_err(|e| StoreError::Database(e.to_string()))?;
410        Ok(())
411    }
412
413    async fn stats(&self) -> Result<StoreStats, StoreError> {
414        let count = self.count().await?;
415        Ok(StoreStats {
416            total_vectors: count,
417            total_dimensions: self.dimensions,
418            index_size_bytes: 0,
419            data_size_bytes: 0,
420            last_indexed_at: None,
421        })
422    }
423}
424
425#[async_trait]
426impl StoreLifecycle for SqliteVecStore {
427    async fn initialize(&self) -> Result<(), StoreError> {
428        sqlx::query(&format!(
429            "CREATE TABLE IF NOT EXISTS {} (
430                id TEXT PRIMARY KEY,
431                content TEXT,
432                metadata_json TEXT,
433                channel TEXT,
434                created_at INTEGER NOT NULL,
435                expires_at INTEGER,
436                embedding BLOB NOT NULL
437            )",
438            self.table_name
439        ))
440        .execute(&self.pool)
441        .await
442        .map_err(|e| StoreError::Database(e.to_string()))?;
443
444        Ok(())
445    }
446
447    async fn close(&self) -> Result<(), StoreError> {
448        self.pool.close().await;
449        Ok(())
450    }
451
452    async fn checkpoint(&self) -> Result<(), StoreError> {
453        sqlx::query("PRAGMA wal_checkpoint(FULL)")
454            .execute(&self.pool)
455            .await
456            .map_err(|e| StoreError::Database(e.to_string()))?;
457        Ok(())
458    }
459
460    async fn health_check(&self) -> Result<bool, StoreError> {
461        sqlx::query("SELECT 1")
462            .execute(&self.pool)
463            .await
464            .map_err(|e| StoreError::Database(e.to_string()))?;
465        Ok(true)
466    }
467}
468
/// One row as selected by the search queries
/// (`id, content, metadata_json, channel, embedding`), mapped by `sqlx::FromRow`.
#[derive(Debug, sqlx::FromRow)]
struct EmbeddingRow {
    /// Primary key of the entry.
    id: String,
    /// Optional content stored alongside the vector (nullable column).
    content: Option<String>,
    /// Entry metadata serialized as a JSON string.
    metadata_json: String,
    /// Optional channel label (nullable column).
    channel: Option<String>,
    /// Vector bytes as written by `vector_to_blob` (4-byte little-endian floats).
    embedding: Vec<u8>,
}