sqlite_knowledge_graph/vector/
store.rs1use crate::error::{Error, Result};
4use rusqlite::params;
5
/// Handle for storing and searching entity embeddings.
///
/// The struct carries no data: every operation takes the `rusqlite::Connection`
/// to work against, and the vectors themselves live in that database's
/// `kg_vectors` table. Because it is a zero-sized unit struct it is trivially
/// `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct VectorStore;
8
9impl Default for VectorStore {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl VectorStore {
16 pub fn new() -> Self {
18 Self
19 }
20
21 pub fn insert_vector(
23 &self,
24 conn: &rusqlite::Connection,
25 entity_id: i64,
26 vector: Vec<f32>,
27 ) -> Result<()> {
28 crate::graph::entity::get_entity(conn, entity_id)?;
30
31 if let Some(existing_dim) = self.check_dimension(conn)? {
33 if existing_dim != vector.len() {
34 return Err(Error::InvalidVectorDimension {
35 expected: existing_dim,
36 actual: vector.len(),
37 });
38 }
39 }
40
41 let mut bytes = Vec::with_capacity(vector.len() * 4);
43 for &val in &vector {
44 bytes.extend_from_slice(&val.to_le_bytes());
45 }
46
47 conn.execute(
48 r#"
49 INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
50 VALUES (?1, ?2, ?3)
51 "#,
52 params![entity_id, bytes, vector.len() as i64],
53 )?;
54
55 Ok(())
56 }
57
58 pub fn insert_vectors_batch(
60 &self,
61 conn: &rusqlite::Connection,
62 vectors: Vec<(i64, Vec<f32>)>,
63 ) -> Result<()> {
64 let tx = conn.unchecked_transaction()?;
65
66 for (entity_id, vector) in vectors {
67 crate::graph::entity::get_entity(conn, entity_id)?;
69
70 let mut bytes = Vec::with_capacity(vector.len() * 4);
72 for &val in &vector {
73 bytes.extend_from_slice(&val.to_le_bytes());
74 }
75
76 tx.execute(
77 r#"
78 INSERT OR REPLACE INTO kg_vectors (entity_id, vector, dimension)
79 VALUES (?1, ?2, ?3)
80 "#,
81 params![entity_id, bytes, vector.len() as i64],
82 )?;
83 }
84
85 tx.commit()?;
86 Ok(())
87 }
88
89 pub fn search_vectors(
91 &self,
92 conn: &rusqlite::Connection,
93 query: Vec<f32>,
94 k: usize,
95 ) -> Result<Vec<SearchResult>> {
96 if k == 0 {
97 return Ok(Vec::new());
98 }
99
100 let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
102
103 let mut results = Vec::new();
104
105 let rows = stmt.query_map([], |row| {
106 let entity_id: i64 = row.get(0)?;
107 let vector_blob: Vec<u8> = row.get(1)?;
108 let dimension: i64 = row.get(2)?;
109
110 let mut vector = Vec::with_capacity(dimension as usize);
112 for chunk in vector_blob.chunks_exact(4) {
113 let bytes: [u8; 4] = chunk.try_into().unwrap();
114 vector.push(f32::from_le_bytes(bytes));
115 }
116
117 Ok((entity_id, vector))
118 })?;
119
120 for row in rows {
121 let (entity_id, vector) = row?;
122
123 let similarity = cosine_similarity(&query, &vector);
125
126 results.push(SearchResult {
127 entity_id,
128 similarity,
129 });
130 }
131
132 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
134
135 Ok(results.into_iter().take(k).collect())
136 }
137
138 pub fn get_vector(&self, conn: &rusqlite::Connection, entity_id: i64) -> Result<Vec<f32>> {
140 let mut stmt =
141 conn.prepare("SELECT vector, dimension FROM kg_vectors WHERE entity_id = ?1")?;
142
143 let vector_blob: Vec<u8> = stmt.query_row(params![entity_id], |row| row.get(0))?;
144
145 let dimension: i64 = stmt.query_row(params![entity_id], |row| row.get(1))?;
146
147 let mut vector = Vec::with_capacity(dimension as usize);
149 for chunk in vector_blob.chunks_exact(4) {
150 let bytes: [u8; 4] = chunk.try_into().unwrap();
151 vector.push(f32::from_le_bytes(bytes));
152 }
153
154 Ok(vector)
155 }
156
157 fn check_dimension(&self, conn: &rusqlite::Connection) -> Result<Option<usize>> {
159 let mut stmt = conn.prepare("SELECT dimension FROM kg_vectors LIMIT 1")?;
160
161 let dimension = stmt.query_row([], |row| {
162 let dim: i64 = row.get(0)?;
163 Ok(Some(dim as usize))
164 });
165
166 match dimension {
167 Ok(dim) => Ok(dim),
168 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
169 Err(e) => Err(Error::SQLite(e)),
170 }
171 }
172}
173
174#[derive(Debug, Clone, serde::Serialize)]
176pub struct SearchResult {
177 pub entity_id: i64,
178 pub similarity: f32,
179}
180
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices have different lengths or when either vector
/// has zero magnitude (including the empty case).
///
/// The sums are accumulated in `f64`. With `f32` accumulation, squared
/// components near `1e20` overflow to infinity (producing `NaN`), and those
/// near `1e-30` underflow to zero (scoring real vectors as `0.0`). The result
/// is clamped so that rounding can never push it slightly outside `[-1, 1]`.
/// `NaN` inputs still yield `NaN`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    similarity.clamp(-1.0, 1.0) as f32
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Opens an in-memory database with the crate schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        let c = vec![0.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &c);
        assert_eq!(sim, 0.0);

        // Orthogonal vectors and mismatched lengths both score zero.
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn test_insert_vector() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        store.insert_vector(&conn, entity_id, vector.clone()).unwrap();

        // Little-endian f32 round-trip is exact.
        assert_eq!(store.get_vector(&conn, entity_id).unwrap(), vector);
    }

    #[test]
    fn test_get_vector_missing() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "No Vector")).unwrap();

        let store = VectorStore::new();
        assert!(store.get_vector(&conn, entity_id).is_err());
    }

    #[test]
    fn test_search_vectors() {
        let conn = setup();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();
        let entity3_id = insert_entity(&conn, &Entity::new("paper", "Paper 3")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity1_id, vec![1.0, 0.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity2_id, vec![0.0, 1.0, 0.0]).unwrap();
        store.insert_vector(&conn, entity3_id, vec![0.9, 0.1, 0.0]).unwrap();

        let results = store.search_vectors(&conn, vec![1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, entity1_id);
        assert_eq!(results[1].entity_id, entity3_id);
    }

    #[test]
    fn test_search_vectors_k_zero() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Paper")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity_id, vec![1.0, 0.0]).unwrap();

        assert!(store.search_vectors(&conn, vec![1.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_invalid_dimension() {
        let conn = setup();
        let entity_id = insert_entity(&conn, &Entity::new("paper", "Test Paper")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, entity_id, vec![0.1, 0.2, 0.3]).unwrap();

        let result = store.insert_vector(&conn, entity_id, vec![0.1, 0.2, 0.3, 0.4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_insert() {
        let conn = setup();

        let entity1_id = insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        let entity2_id = insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();

        let store = VectorStore::new();
        let vectors = vec![
            (entity1_id, vec![0.1, 0.2, 0.3]),
            (entity2_id, vec![0.4, 0.5, 0.6]),
        ];

        store.insert_vectors_batch(&conn, vectors).unwrap();

        let results = store.search_vectors(&conn, vec![0.1, 0.2, 0.3], 10).unwrap();
        assert_eq!(results.len(), 2);
    }
}