1#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
/// redb table mapping a vector id to its bincode-encoded `Vec<f32>`.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");

/// redb table mapping a vector id to its metadata, serialized as a JSON string.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");

/// Process-wide cache of open databases, keyed by the path handed to `Database::create`.
///
/// `VectorStorage::new` looks here before opening a file so that several handles on the
/// same path share one `Database` instead of opening the file twice. No code visible in
/// this module removes entries, so once opened a database stays open for the rest of the
/// process.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
36
37pub struct VectorStorage {
39 db: Arc<Database>,
40 dimensions: usize,
41}
42
43impl VectorStorage {
44 pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
50 let path_buf = path
51 .as_ref()
52 .canonicalize()
53 .unwrap_or_else(|_| path.as_ref().to_path_buf());
54
55 let db = {
57 let mut pool = DB_POOL.lock();
58
59 if let Some(existing_db) = pool.get(&path_buf) {
60 Arc::clone(existing_db)
62 } else {
63 let new_db = Arc::new(Database::create(&path_buf)?);
65
66 let write_txn = new_db.begin_write()?;
68 {
69 let _ = write_txn.open_table(VECTORS_TABLE)?;
70 let _ = write_txn.open_table(METADATA_TABLE)?;
71 }
72 write_txn.commit()?;
73
74 pool.insert(path_buf, Arc::clone(&new_db));
75 new_db
76 }
77 };
78
79 Ok(Self { db, dimensions })
80 }
81
82 pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
84 if entry.vector.len() != self.dimensions {
85 return Err(RuvectorError::DimensionMismatch {
86 expected: self.dimensions,
87 actual: entry.vector.len(),
88 });
89 }
90
91 let id = entry
92 .id
93 .clone()
94 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
95
96 let write_txn = self.db.begin_write()?;
97 {
98 let mut table = write_txn.open_table(VECTORS_TABLE)?;
99
100 let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
102 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
103
104 table.insert(id.as_str(), vector_data.as_slice())?;
105
106 if let Some(metadata) = &entry.metadata {
108 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
109 let metadata_json = serde_json::to_string(metadata)
110 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
111 meta_table.insert(id.as_str(), metadata_json.as_str())?;
112 }
113 }
114 write_txn.commit()?;
115
116 Ok(id)
117 }
118
119 pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
121 let write_txn = self.db.begin_write()?;
122 let mut ids = Vec::with_capacity(entries.len());
123
124 {
125 let mut table = write_txn.open_table(VECTORS_TABLE)?;
126 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
127
128 for entry in entries {
129 if entry.vector.len() != self.dimensions {
130 return Err(RuvectorError::DimensionMismatch {
131 expected: self.dimensions,
132 actual: entry.vector.len(),
133 });
134 }
135
136 let id = entry
137 .id
138 .clone()
139 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
140
141 let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
143 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
144 table.insert(id.as_str(), vector_data.as_slice())?;
145
146 if let Some(metadata) = &entry.metadata {
148 let metadata_json = serde_json::to_string(metadata)
149 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
150 meta_table.insert(id.as_str(), metadata_json.as_str())?;
151 }
152
153 ids.push(id);
154 }
155 }
156
157 write_txn.commit()?;
158 Ok(ids)
159 }
160
161 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
163 let read_txn = self.db.begin_read()?;
164 let table = read_txn.open_table(VECTORS_TABLE)?;
165
166 let Some(vector_data) = table.get(id)? else {
167 return Ok(None);
168 };
169
170 let (vector, _): (Vec<f32>, usize) =
171 bincode::decode_from_slice(vector_data.value(), config::standard())
172 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
173
174 let meta_table = read_txn.open_table(METADATA_TABLE)?;
176 let metadata = if let Some(meta_data) = meta_table.get(id)? {
177 let meta_str = meta_data.value();
178 Some(
179 serde_json::from_str(meta_str)
180 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
181 )
182 } else {
183 None
184 };
185
186 Ok(Some(VectorEntry {
187 id: Some(id.to_string()),
188 vector,
189 metadata,
190 }))
191 }
192
193 pub fn delete(&self, id: &str) -> Result<bool> {
195 let write_txn = self.db.begin_write()?;
196 let mut deleted = false;
197
198 {
199 let mut table = write_txn.open_table(VECTORS_TABLE)?;
200 deleted = table.remove(id)?.is_some();
201
202 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
203 let _ = meta_table.remove(id)?;
204 }
205
206 write_txn.commit()?;
207 Ok(deleted)
208 }
209
210 pub fn len(&self) -> Result<usize> {
212 let read_txn = self.db.begin_read()?;
213 let table = read_txn.open_table(VECTORS_TABLE)?;
214 Ok(table.len()? as usize)
215 }
216
217 pub fn is_empty(&self) -> Result<bool> {
219 Ok(self.len()? == 0)
220 }
221
222 pub fn all_ids(&self) -> Result<Vec<VectorId>> {
224 let read_txn = self.db.begin_read()?;
225 let table = read_txn.open_table(VECTORS_TABLE)?;
226
227 let mut ids = Vec::new();
228 let iter = table.iter()?;
229 for item in iter {
230 let (key, _) = item?;
231 ids.push(key.value().to_string());
232 }
233
234 Ok(ids)
235 }
236}
237
238use uuid;
240
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an entry with an explicit id and no metadata.
    fn entry(id: &str, vector: Vec<f32>) -> VectorEntry {
        VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata: None,
        }
    }

    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let id = storage.insert(&entry("test1", vec![1.0, 2.0, 3.0]))?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        storage.insert(&entry("test1", vec![1.0, 2.0, 3.0]))?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        let storage1 = VectorStorage::new(&db_path, 3)?;
        storage1.insert(&entry("test1", vec![1.0, 2.0, 3.0]))?;

        // A second handle on the same path must see data written through the first.
        let storage2 = VectorStorage::new(&db_path, 3)?;
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        storage2.insert(&entry("test2", vec![4.0, 5.0, 6.0]))?;
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        assert!(storage1.get("test1")?.is_some());
        assert!(storage2.get("test2")?.is_some());

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let result = storage.insert(&entry("short", vec![1.0, 2.0]));
        assert!(matches!(
            result,
            Err(RuvectorError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
        assert!(storage.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_batch_insert_is_atomic_on_dimension_mismatch() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![entry("ok", vec![1.0, 2.0, 3.0]), entry("bad", vec![1.0])];
        assert!(storage.insert_batch(&entries).is_err());

        // The valid entry ahead of the bad one must not have been committed.
        assert_eq!(storage.len()?, 0);
        assert!(storage.get("ok")?.is_none());

        Ok(())
    }

    #[test]
    fn test_missing_ids() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        assert!(storage.get("missing")?.is_none());
        assert!(!storage.delete("missing")?);
        assert!(storage.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_all_ids() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        storage.insert_batch(&[entry("b", vec![1.0, 2.0, 3.0]), entry("a", vec![4.0, 5.0, 6.0])])?;

        let mut ids = storage.all_ids()?;
        ids.sort();
        assert_eq!(ids, vec!["a".to_string(), "b".to_string()]);

        Ok(())
    }

    #[test]
    fn test_insert_overwrites_existing_id() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        storage.insert(&entry("x", vec![1.0, 2.0, 3.0]))?;
        storage.insert(&entry("x", vec![4.0, 5.0, 6.0]))?;

        assert_eq!(storage.len()?, 1);
        assert_eq!(storage.get("x")?.unwrap().vector, vec![4.0, 5.0, 6.0]);

        Ok(())
    }
}
358}