Skip to main content

sqlite_knowledge_graph/vector/
store.rs

1//! Vector storage module for semantic search.
2
3use crate::error::{Error, Result};
4use rusqlite::params;
5
/// Represents a vector storage for embeddings.
///
/// The handle itself carries no state: every method receives the `rusqlite::Connection`
/// it operates on, so the value is zero-sized and free to copy.
#[derive(Debug, Clone, Copy)]
pub struct VectorStore;
8
9impl Default for VectorStore {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl VectorStore {
16    /// Create a new vector store.
17    pub fn new() -> Self {
18        Self
19    }
20
21    /// Insert a vector for an entity.
22    pub fn insert_vector(
23        &self,
24        conn: &rusqlite::Connection,
25        entity_id: i64,
26        vector: Vec<f32>,
27    ) -> Result<()> {
28        // Validate entity exists
29        crate::graph::entity::get_entity(conn, entity_id)?;
30
31        // Check if vector dimension matches existing vectors for consistency
32        if let Some(existing_dim) = self.check_dimension(conn)? {
33            if existing_dim != vector.len() {
34                return Err(Error::InvalidVectorDimension {
35                    expected: existing_dim,
36                    actual: vector.len(),
37                });
38            }
39        }
40
41        // Serialize vector to bytes
42        let mut bytes = Vec::with_capacity(vector.len() * 4);
43        for &val in &vector {
44            bytes.extend_from_slice(&val.to_le_bytes());
45        }
46
47        conn.execute(
48            r#"
49            INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
50            VALUES (?1, ?2, ?3)
51            "#,
52            params![entity_id, bytes, vector.len() as i64],
53        )?;
54
55        Ok(())
56    }
57
58    /// Batch insert vectors.
59    pub fn insert_vectors_batch(
60        &self,
61        conn: &rusqlite::Connection,
62        vectors: Vec<(i64, Vec<f32>)>,
63    ) -> Result<()> {
64        let tx = conn.unchecked_transaction()?;
65
66        for (entity_id, vector) in vectors {
67            // Validate entity exists
68            crate::graph::entity::get_entity(conn, entity_id)?;
69
70            // Serialize vector to bytes
71            let mut bytes = Vec::with_capacity(vector.len() * 4);
72            for &val in &vector {
73                bytes.extend_from_slice(&val.to_le_bytes());
74            }
75
76            tx.execute(
77                r#"
78                INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
79                VALUES (?1, ?2, ?3)
80                "#,
81                params![entity_id, bytes, vector.len() as i64],
82            )?;
83        }
84
85        tx.commit()?;
86        Ok(())
87    }
88
89    /// Search for similar vectors using cosine similarity.
90    pub fn search_vectors(
91        &self,
92        conn: &rusqlite::Connection,
93        query: Vec<f32>,
94        k: usize,
95    ) -> Result<Vec<SearchResult>> {
96        if k == 0 {
97            return Ok(Vec::new());
98        }
99
100        // Get all vectors
101        let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
102
103        let mut results = Vec::new();
104
105        let rows = stmt.query_map([], |row| {
106            let entity_id: i64 = row.get(0)?;
107            let vector_blob: Vec<u8> = row.get(1)?;
108            let dimension: i64 = row.get(2)?;
109
110            // Deserialize vector
111            let mut vector = Vec::with_capacity(dimension as usize);
112            for chunk in vector_blob.chunks_exact(4) {
113                let bytes: [u8; 4] = chunk.try_into().unwrap();
114                vector.push(f32::from_le_bytes(bytes));
115            }
116
117            Ok((entity_id, vector))
118        })?;
119
120        for row in rows {
121            let (entity_id, vector) = row?;
122
123            // Calculate cosine similarity
124            let similarity = cosine_similarity(&query, &vector);
125
126            results.push(SearchResult {
127                entity_id,
128                similarity,
129            });
130        }
131
132        // Sort by similarity (descending) and take top k
133        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
134
135        Ok(results.into_iter().take(k).collect())
136    }
137
138    /// Get vector for an entity.
139    pub fn get_vector(&self, conn: &rusqlite::Connection, entity_id: i64) -> Result<Vec<f32>> {
140        let mut stmt =
141            conn.prepare("SELECT vector, dimension FROM kg_vectors WHERE entity_id = ?1")?;
142
143        let vector_blob: Vec<u8> = stmt.query_row(params![entity_id], |row| row.get(0))?;
144
145        let dimension: i64 = stmt.query_row(params![entity_id], |row| row.get(1))?;
146
147        // Deserialize vector
148        let mut vector = Vec::with_capacity(dimension as usize);
149        for chunk in vector_blob.chunks_exact(4) {
150            let bytes: [u8; 4] = chunk.try_into().unwrap();
151            vector.push(f32::from_le_bytes(bytes));
152        }
153
154        Ok(vector)
155    }
156
157    /// Check if vectors exist and get their dimension.
158    fn check_dimension(&self, conn: &rusqlite::Connection) -> Result<Option<usize>> {
159        let mut stmt = conn.prepare("SELECT dimension FROM kg_vectors LIMIT 1")?;
160
161        let dimension = stmt.query_row([], |row| {
162            let dim: i64 = row.get(0)?;
163            Ok(Some(dim as usize))
164        });
165
166        match dimension {
167            Ok(dim) => Ok(dim),
168            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
169            Err(e) => Err(Error::SQLite(e)),
170        }
171    }
172}
173
174/// Represents a search result from vector similarity search.
175#[derive(Debug, Clone, serde::Serialize)]
176pub struct SearchResult {
177    pub entity_id: i64,
178    pub similarity: f32,
179}
180
/// Calculate cosine similarity between two vectors.
///
/// Returns `1.0` for vectors pointing the same way, `0.0` for orthogonal ones and `-1.0`
/// for opposite ones. The result is clamped to `[-1.0, 1.0]` to absorb rounding overshoot.
///
/// Returns `0.0` when the slices have different lengths, or when either vector has zero
/// norm (including empty slices). Non-finite components may produce `NaN`.
///
/// Sums are accumulated in `f64`. With `f32` accumulation, squaring large components
/// overflows to infinity (giving `inf / inf = NaN`) and squaring tiny ones underflows to
/// zero (giving a spurious `0.0`).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    (dot / (norm_a.sqrt() * norm_b.sqrt())).clamp(-1.0, 1.0) as f32
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Open an in-memory database with the crate schema applied.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        let c = vec![0.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &c);
        assert_eq!(sim, 0.0);

        // Mismatched lengths are treated as dissimilar rather than an error.
        assert_eq!(cosine_similarity(&a, &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_insert_vector() {
        let conn = setup_db();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        store.insert_vector(&conn, entity_id, vector).unwrap();
    }

    #[test]
    fn test_get_vector_roundtrip() {
        let conn = setup_db();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector = vec![0.1_f32, -0.2, 0.3, 1e-7];
        store.insert_vector(&conn, entity_id, vector.clone()).unwrap();

        // The little-endian byte encoding must round-trip bit-exactly.
        assert_eq!(store.get_vector(&conn, entity_id).unwrap(), vector);
    }

    #[test]
    fn test_get_vector_missing() {
        let conn = setup_db();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "No Vector")).unwrap();

        let store = VectorStore::new();
        assert!(store.get_vector(&conn, entity_id).is_err());
    }

    #[test]
    fn test_search_vectors() {
        let conn = setup_db();
        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();
        let entity3_id = insert_entity(&conn, &Entity::new("paper", "Paper 3")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity1_id, vec![1.0, 0.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity2_id, vec![0.0, 1.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity3_id, vec![0.9, 0.1, 0.0]).unwrap();

        let results = store.search_vectors(&conn, vec![1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, entity1_id);
        assert_eq!(results[1].entity_id, entity3_id);

        // k == 0 short-circuits to no results.
        assert!(store
            .search_vectors(&conn, vec![1.0, 0.0, 0.0], 0)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_invalid_dimension() {
        let conn = setup_db();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();
        let other_id = insert_entity(&conn, &Entity::new("paper", "Other Paper")).unwrap();

        let store = VectorStore::new();
        let vector1 = vec![0.1, 0.2, 0.3];
        let vector2 = vec![0.1, 0.2, 0.3, 0.4];

        store.insert_vector(&conn, entity_id, vector1).unwrap();

        // Replacing the same entity's vector with another dimension is rejected...
        let result = store.insert_vector(&conn, entity_id, vector2.clone());
        assert!(result.is_err());

        // ...and so is adding a different entity with a mismatched dimension.
        let result = store.insert_vector(&conn, other_id, vector2);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_insert() {
        let conn = setup_db();
        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();

        let store = VectorStore::new();
        let vectors = vec![
            (entity1_id, vec![0.1, 0.2, 0.3]),
            (entity2_id, vec![0.4, 0.5, 0.6]),
        ];

        store.insert_vectors_batch(&conn, vectors).unwrap();

        let results = store.search_vectors(&conn, vec![0.1, 0.2, 0.3], 10).unwrap();
        assert_eq!(results.len(), 2);
    }
}