Skip to main content

sqlite_knowledge_graph/vector/
store.rs

1//! Vector storage module for semantic search.
2
3use crate::error::{Error, Result};
4use rusqlite::params;
5
/// Stateless handle for storing and searching entity embeddings.
///
/// The handle holds no data of its own: every method operates on the
/// `kg_vectors` table of the `rusqlite::Connection` passed to it, so the
/// handle is trivially copyable.
#[derive(Debug, Clone, Copy)]
pub struct VectorStore;
8
9impl Default for VectorStore {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl VectorStore {
16    /// Create a new vector store.
17    pub fn new() -> Self {
18        Self
19    }
20
21    /// Insert a vector for an entity.
22    pub fn insert_vector(
23        &self,
24        conn: &rusqlite::Connection,
25        entity_id: i64,
26        vector: Vec<f32>,
27    ) -> Result<()> {
28        // Validate entity exists
29        crate::graph::entity::get_entity(conn, entity_id)?;
30
31        // Check if vector dimension matches existing vectors for consistency
32        if let Some(existing_dim) = self.check_dimension(conn)? {
33            if existing_dim != vector.len() {
34                return Err(Error::InvalidVectorDimension {
35                    expected: existing_dim,
36                    actual: vector.len(),
37                });
38            }
39        }
40
41        // Serialize vector to bytes
42        let mut bytes = Vec::with_capacity(vector.len() * 4);
43        for &val in &vector {
44            bytes.extend_from_slice(&val.to_le_bytes());
45        }
46
47        conn.execute(
48            r#"
49            INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
50            VALUES (?1, ?2, ?3)
51            "#,
52            params![entity_id, bytes, vector.len() as i64],
53        )?;
54
55        Ok(())
56    }
57
58    /// Batch insert vectors.
59    pub fn insert_vectors_batch(
60        &self,
61        conn: &rusqlite::Connection,
62        vectors: Vec<(i64, Vec<f32>)>,
63    ) -> Result<()> {
64        let tx = conn.unchecked_transaction()?;
65
66        for (entity_id, vector) in vectors {
67            // Serialize vector to bytes (FK constraint enforces entity existence)
68            let mut bytes = Vec::with_capacity(vector.len() * 4);
69            for &val in &vector {
70                bytes.extend_from_slice(&val.to_le_bytes());
71            }
72
73            tx.execute(
74                r#"
75                INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
76                VALUES (?1, ?2, ?3)
77                "#,
78                params![entity_id, bytes, vector.len() as i64],
79            )?;
80        }
81
82        tx.commit()?;
83        Ok(())
84    }
85
86    /// Search for similar vectors using cosine similarity.
87    pub fn search_vectors(
88        &self,
89        conn: &rusqlite::Connection,
90        query: Vec<f32>,
91        k: usize,
92    ) -> Result<Vec<SearchResult>> {
93        if k == 0 {
94            return Ok(Vec::new());
95        }
96
97        // Get all vectors
98        let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
99
100        let mut results = Vec::new();
101
102        let rows = stmt.query_map([], |row| {
103            let entity_id: i64 = row.get(0)?;
104            let vector_blob: Vec<u8> = row.get(1)?;
105            let dimension: i64 = row.get(2)?;
106
107            // Deserialize vector
108            let mut vector = Vec::with_capacity(dimension as usize);
109            for chunk in vector_blob.chunks_exact(4) {
110                let bytes: [u8; 4] = chunk.try_into().unwrap();
111                vector.push(f32::from_le_bytes(bytes));
112            }
113
114            Ok((entity_id, vector))
115        })?;
116
117        for row in rows {
118            let (entity_id, vector) = row?;
119
120            // Calculate cosine similarity
121            let similarity = cosine_similarity(&query, &vector);
122
123            results.push(SearchResult {
124                entity_id,
125                similarity,
126            });
127        }
128
129        // Sort by similarity (descending) and take top k
130        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
131
132        Ok(results.into_iter().take(k).collect())
133    }
134
135    /// Get vector for an entity.
136    pub fn get_vector(&self, conn: &rusqlite::Connection, entity_id: i64) -> Result<Vec<f32>> {
137        let mut stmt =
138            conn.prepare("SELECT vector, dimension FROM kg_vectors WHERE entity_id = ?1")?;
139
140        let (vector_blob, dimension): (Vec<u8>, i64) =
141            stmt.query_row(params![entity_id], |row| Ok((row.get(0)?, row.get(1)?)))?;
142
143        // Deserialize vector
144        let mut vector = Vec::with_capacity(dimension as usize);
145        for chunk in vector_blob.chunks_exact(4) {
146            let bytes: [u8; 4] = chunk.try_into().unwrap();
147            vector.push(f32::from_le_bytes(bytes));
148        }
149
150        Ok(vector)
151    }
152
153    /// Check if vectors exist and get their dimension.
154    fn check_dimension(&self, conn: &rusqlite::Connection) -> Result<Option<usize>> {
155        let mut stmt = conn.prepare("SELECT dimension FROM kg_vectors LIMIT 1")?;
156
157        let dimension = stmt.query_row([], |row| {
158            let dim: i64 = row.get(0)?;
159            Ok(Some(dim as usize))
160        });
161
162        match dimension {
163            Ok(dim) => Ok(dim),
164            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
165            Err(e) => Err(Error::SQLite(e)),
166        }
167    }
168}
169
170/// Represents a search result from vector similarity search.
171#[derive(Debug, Clone, serde::Serialize)]
172pub struct SearchResult {
173    pub entity_id: i64,
174    pub similarity: f32,
175}
176
/// Calculate the cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the slices differ in
/// length, are empty, or either has zero magnitude, because the similarity is
/// undefined in those cases. NaN or infinite components yield NaN.
///
/// Sums are accumulated in `f64`. Squaring large `f32` components would
/// otherwise overflow to infinity (turning the result into NaN), and long
/// embeddings would lose precision in single-precision sums.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp so callers
    // can rely on the documented range (`clamp` leaves NaN untouched).
    (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())).clamp(-1.0, 1.0) as f32
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Open an in-memory database with the crate schema installed.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        let c = vec![0.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &c), 0.0);

        // Orthogonal vectors and mismatched lengths both score zero.
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&a, &[1.0, 2.0]), 0.0);

        // Opposite directions score -1.
        assert!((cosine_similarity(&[1.0, 0.0], &[-3.0, 0.0]) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_insert_vector() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        store.insert_vector(&conn, entity_id, vector.clone()).unwrap();

        // The little-endian encoding must round-trip bit-for-bit.
        assert_eq!(store.get_vector(&conn, entity_id).unwrap(), vector);
    }

    #[test]
    fn test_search_vectors() {
        let conn = setup();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();
        let entity3_id = insert_entity(&conn, &Entity::new("paper", "Paper 3")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity1_id, vec![1.0, 0.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity2_id, vec![0.0, 1.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity3_id, vec![0.9, 0.1, 0.0]).unwrap();

        let results = store.search_vectors(&conn, vec![1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, entity1_id);
        assert_eq!(results[1].entity_id, entity3_id);
        assert!((results[0].similarity - 1.0).abs() < 0.001);
        assert!(results[0].similarity >= results[1].similarity);

        // k == 0 short-circuits to an empty result.
        assert!(store
            .search_vectors(&conn, vec![1.0, 0.0, 0.0], 0)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_search_empty_store() {
        let conn = setup();
        let store = VectorStore::new();

        let results = store.search_vectors(&conn, vec![1.0, 0.0], 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_invalid_dimension() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity_id, vec![0.1, 0.2, 0.3]).unwrap();

        let result = store.insert_vector(&conn, entity_id, vec![0.1, 0.2, 0.3, 0.4]);
        assert!(matches!(
            result,
            Err(Error::InvalidVectorDimension {
                expected: 3,
                actual: 4
            })
        ));

        // The rejected insert must leave the stored vector untouched.
        assert_eq!(
            store.get_vector(&conn, entity_id).unwrap(),
            vec![0.1, 0.2, 0.3]
        );
    }

    #[test]
    fn test_batch_insert() {
        let conn = setup();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();

        let store = VectorStore::new();
        let vectors = vec![
            (entity1_id, vec![0.1, 0.2, 0.3]),
            (entity2_id, vec![0.4, 0.5, 0.6]),
        ];

        store.insert_vectors_batch(&conn, vectors).unwrap();

        let results = store.search_vectors(&conn, vec![0.1, 0.2, 0.3], 10).unwrap();
        assert_eq!(results.len(), 2);

        assert_eq!(
            store.get_vector(&conn, entity2_id).unwrap(),
            vec![0.4, 0.5, 0.6]
        );
    }
}