sqlite_knowledge_graph/vector/
store.rs1use crate::error::{Error, Result};
4use rusqlite::params;
5
/// Stateless handle for storing and searching entity embedding vectors.
///
/// The type carries no data of its own: every operation receives the SQLite
/// connection it should work on, so values are free to copy and print.
#[derive(Debug, Clone, Copy)]
pub struct VectorStore;
8
9impl Default for VectorStore {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl VectorStore {
16 pub fn new() -> Self {
18 Self
19 }
20
21 pub fn insert_vector(
23 &self,
24 conn: &rusqlite::Connection,
25 entity_id: i64,
26 vector: Vec<f32>,
27 ) -> Result<()> {
28 crate::graph::entity::get_entity(conn, entity_id)?;
30
31 if let Some(existing_dim) = self.check_dimension(conn)? {
33 if existing_dim != vector.len() {
34 return Err(Error::InvalidVectorDimension {
35 expected: existing_dim,
36 actual: vector.len(),
37 });
38 }
39 }
40
41 let mut bytes = Vec::with_capacity(vector.len() * 4);
43 for &val in &vector {
44 bytes.extend_from_slice(&val.to_le_bytes());
45 }
46
47 conn.execute(
48 r#"
49 INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
50 VALUES (?1, ?2, ?3)
51 "#,
52 params![entity_id, bytes, vector.len() as i64],
53 )?;
54
55 Ok(())
56 }
57
58 pub fn insert_vectors_batch(
60 &self,
61 conn: &rusqlite::Connection,
62 vectors: Vec<(i64, Vec<f32>)>,
63 ) -> Result<()> {
64 let tx = conn.unchecked_transaction()?;
65
66 for (entity_id, vector) in vectors {
67 let mut bytes = Vec::with_capacity(vector.len() * 4);
69 for &val in &vector {
70 bytes.extend_from_slice(&val.to_le_bytes());
71 }
72
73 tx.execute(
74 r#"
75 INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
76 VALUES (?1, ?2, ?3)
77 "#,
78 params![entity_id, bytes, vector.len() as i64],
79 )?;
80 }
81
82 tx.commit()?;
83 Ok(())
84 }
85
86 pub fn search_vectors(
88 &self,
89 conn: &rusqlite::Connection,
90 query: Vec<f32>,
91 k: usize,
92 ) -> Result<Vec<SearchResult>> {
93 if k == 0 {
94 return Ok(Vec::new());
95 }
96
97 let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
99
100 let mut results = Vec::new();
101
102 let rows = stmt.query_map([], |row| {
103 let entity_id: i64 = row.get(0)?;
104 let vector_blob: Vec<u8> = row.get(1)?;
105 let dimension: i64 = row.get(2)?;
106
107 let mut vector = Vec::with_capacity(dimension as usize);
109 for chunk in vector_blob.chunks_exact(4) {
110 let bytes: [u8; 4] = chunk.try_into().unwrap();
111 vector.push(f32::from_le_bytes(bytes));
112 }
113
114 Ok((entity_id, vector))
115 })?;
116
117 for row in rows {
118 let (entity_id, vector) = row?;
119
120 let similarity = cosine_similarity(&query, &vector);
122
123 results.push(SearchResult {
124 entity_id,
125 similarity,
126 });
127 }
128
129 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
131
132 Ok(results.into_iter().take(k).collect())
133 }
134
135 pub fn get_vector(&self, conn: &rusqlite::Connection, entity_id: i64) -> Result<Vec<f32>> {
137 let mut stmt =
138 conn.prepare("SELECT vector, dimension FROM kg_vectors WHERE entity_id = ?1")?;
139
140 let (vector_blob, dimension): (Vec<u8>, i64) =
141 stmt.query_row(params![entity_id], |row| Ok((row.get(0)?, row.get(1)?)))?;
142
143 let mut vector = Vec::with_capacity(dimension as usize);
145 for chunk in vector_blob.chunks_exact(4) {
146 let bytes: [u8; 4] = chunk.try_into().unwrap();
147 vector.push(f32::from_le_bytes(bytes));
148 }
149
150 Ok(vector)
151 }
152
153 fn check_dimension(&self, conn: &rusqlite::Connection) -> Result<Option<usize>> {
155 let mut stmt = conn.prepare("SELECT dimension FROM kg_vectors LIMIT 1")?;
156
157 let dimension = stmt.query_row([], |row| {
158 let dim: i64 = row.get(0)?;
159 Ok(Some(dim as usize))
160 });
161
162 match dimension {
163 Ok(dim) => Ok(dim),
164 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
165 Err(e) => Err(Error::SQLite(e)),
166 }
167 }
168}
169
/// One hit produced by `VectorStore::search_vectors`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SearchResult {
    /// Identifier of the entity whose stored vector was compared to the query.
    pub entity_id: i64,
    /// Cosine similarity between the query and that entity's stored vector.
    pub similarity: f32,
}
176
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either one has a
/// zero squared norm (which includes two empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    dot / (sq_a.sqrt() * sq_b.sqrt())
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Opens an in-memory database with the knowledge-graph schema applied.
    fn test_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Identical vectors score ~1; zero, orthogonal and mismatched inputs score 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        let c = vec![0.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &c);
        assert_eq!(sim, 0.0);

        // Orthogonal vectors.
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);

        // Length mismatch is reported as zero similarity.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0]), 0.0);
    }

    /// A stored vector can be read back unchanged.
    #[test]
    fn test_insert_vector() {
        let conn = test_conn();

        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];

        store
            .insert_vector(&conn, entity_id, vector.clone())
            .unwrap();

        // The little-endian byte encoding round-trips f32 values exactly.
        assert_eq!(store.get_vector(&conn, entity_id).unwrap(), vector);
    }

    /// Search returns the `k` most similar entities, best match first.
    #[test]
    fn test_search_vectors() {
        let conn = test_conn();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();
        let entity3_id = insert_entity(&conn, &Entity::new("paper", "Paper 3")).unwrap();

        let store = VectorStore::new();
        let vector1 = vec![1.0, 0.0, 0.0];
        let vector2 = vec![0.0, 1.0, 0.0];
        let vector3 = vec![0.9, 0.1, 0.0];

        store.insert_vector(&conn, entity1_id, vector1).unwrap();
        store.insert_vector(&conn, entity2_id, vector2).unwrap();
        store.insert_vector(&conn, entity3_id, vector3).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = store.search_vectors(&conn, query.clone(), 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, entity1_id);
        assert_eq!(results[1].entity_id, entity3_id);
        assert!((results[0].similarity - 1.0).abs() < 0.001);
        assert!(results[0].similarity >= results[1].similarity);

        // Asking for zero results yields an empty list.
        assert!(store.search_vectors(&conn, query, 0).unwrap().is_empty());
    }

    /// A vector whose dimension differs from the stored one is rejected.
    #[test]
    fn test_invalid_dimension() {
        let conn = test_conn();

        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector1: Vec<f32> = vec![0.1, 0.2, 0.3];
        let vector2: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];

        store
            .insert_vector(&conn, entity_id, vector1.clone())
            .unwrap();

        let result = store.insert_vector(&conn, entity_id, vector2);
        assert!(matches!(
            result,
            Err(Error::InvalidVectorDimension {
                expected: 3,
                actual: 4
            })
        ));

        // The rejected insert must not have replaced the stored vector.
        assert_eq!(store.get_vector(&conn, entity_id).unwrap(), vector1);
    }

    /// Batch insertion stores every vector and makes them searchable.
    #[test]
    fn test_batch_insert() {
        let conn = test_conn();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();

        let store = VectorStore::new();
        let vectors = vec![
            (entity1_id, vec![0.1, 0.2, 0.3]),
            (entity2_id, vec![0.4, 0.5, 0.6]),
        ];

        store.insert_vectors_batch(&conn, vectors).unwrap();

        let query = vec![0.1, 0.2, 0.3];
        let results = store.search_vectors(&conn, query, 10).unwrap();
        assert_eq!(results.len(), 2);
        // The query equals entity 1's vector, so it must rank first.
        assert_eq!(results[0].entity_id, entity1_id);

        let stored: Vec<f32> = store.get_vector(&conn, entity2_id).unwrap();
        assert_eq!(stored, vec![0.4_f32, 0.5, 0.6]);
    }
}