1#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{DbOptions, VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
/// Table mapping vector id -> bincode-encoded `Vec<f32>` bytes.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
/// Table mapping vector id -> JSON-serialized metadata string.
// NOTE: this and everything below must carry the same cfg gate as the
// imports above, otherwise the file fails to compile when the "storage"
// feature is disabled (TableDefinition/Lazy/Mutex/Database are not in scope).
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
/// Table holding database-level configuration entries.
#[cfg(feature = "storage")]
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");

/// Key under which the serialized `DbOptions` live in `CONFIG_TABLE`.
#[cfg(feature = "storage")]
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";

/// Process-wide pool of open databases keyed by absolute path, so that
/// multiple `VectorStorage` handles over the same file share one
/// `redb::Database` (redb allows only one open handle per file).
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
39
/// Persistent vector store backed by a pooled, shared `redb` database.
///
/// All vectors stored through one instance must have exactly
/// `dimensions` elements.
// cfg gate added for consistency: the struct's field types (Arc<Database>)
// are only imported under the "storage" feature.
#[cfg(feature = "storage")]
pub struct VectorStorage {
    /// Shared handle to the underlying database (one per path, see DB_POOL).
    db: Arc<Database>,
    /// Required length of every stored vector.
    dimensions: usize,
}
45
#[cfg(feature = "storage")]
impl VectorStorage {
    /// Opens (or creates) the vector database at `path`.
    ///
    /// Databases are pooled per absolute path, so two `VectorStorage`
    /// instances pointed at the same file share one `redb::Database`.
    ///
    /// # Errors
    /// Returns `InvalidPath` for path-traversal attempts or filesystem
    /// failures, and propagates `redb` errors from opening the database.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        let path_ref = path.as_ref();

        // Validate the path BEFORE any filesystem side effects. Previously
        // directories were created first, so a path that was then rejected
        // for traversal had already left directories behind on disk.
        Self::reject_traversal(path_ref)?;

        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Canonical pool key: always an absolute path, so relative and
        // absolute spellings of the same file hit the same pool entry.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                Arc::clone(existing_db)
            } else {
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Eagerly create every table so later read transactions
                // never observe a missing table.
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Rejects relative paths whose `..` components would escape the
    /// current working directory. Absolute paths are accepted as-is
    /// (same policy as before; only the check's timing moved).
    fn reject_traversal(path_ref: &Path) -> Result<()> {
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") && !path_ref.is_absolute() {
            if let Ok(cwd) = std::env::current_dir() {
                // Re-resolve component by component; popping past the cwd
                // (or out from under it) means the caller tried to escape.
                let mut normalized = cwd.clone();
                for component in path_ref.components() {
                    match component {
                        std::path::Component::ParentDir => {
                            if !normalized.pop() || !normalized.starts_with(&cwd) {
                                return Err(RuvectorError::InvalidPath(
                                    "Path traversal attempt detected".to_string(),
                                ));
                            }
                        }
                        std::path::Component::Normal(c) => normalized.push(c),
                        _ => {}
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns `DimensionMismatch` unless `len` equals the configured
    /// dimensionality. Shared by `insert` and `insert_batch`.
    fn check_dimensions(&self, len: usize) -> Result<()> {
        if len != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: len,
            });
        }
        Ok(())
    }

    /// Uses the entry's own id, or mints a fresh UUIDv4 string.
    fn entry_id(entry: &VectorEntry) -> VectorId {
        entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
    }

    /// Inserts one entry and returns its (possibly generated) id.
    ///
    /// The vector is bincode-encoded; metadata, when present, is stored
    /// as JSON in a separate table under the same id.
    ///
    /// # Errors
    /// `DimensionMismatch` on wrong vector length, `SerializationError`
    /// on encode failure, plus any `redb` transaction error.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.check_dimensions(entry.vector.len())?;
        let id = Self::entry_id(entry);

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Inserts many entries inside a single transaction and returns their
    /// ids in input order.
    ///
    /// On any error the transaction is dropped uncommitted, so redb aborts
    /// it and none of the batch is persisted (all-or-nothing).
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                self.check_dimensions(entry.vector.len())?;
                let id = Self::entry_id(entry);

                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Fetches the entry with `id`, or `Ok(None)` when it does not exist.
    ///
    /// # Errors
    /// `SerializationError` if stored bytes/JSON fail to decode, plus any
    /// `redb` read error.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        // decode_from_slice also returns the number of bytes consumed.
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Metadata is optional; absence in the table simply means None.
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Deletes the entry (vector and metadata) with `id`.
    ///
    /// Returns `true` if a vector was actually removed.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();

            // Metadata may or may not exist; remove it unconditionally.
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Number of stored vectors.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// `true` when no vectors are stored.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Returns every stored vector id (unspecified order beyond redb's
    /// key ordering).
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }

    /// Persists database-level options as JSON under `DB_CONFIG_KEY`.
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;

        Ok(())
    }

    /// Loads previously saved options; `Ok(None)` when the config table or
    /// key is absent (deliberate best-effort: a missing table is not an
    /// error, but malformed JSON is).
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;

        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };

        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };

        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        Ok(Some(config))
    }

    /// Configured vector dimensionality.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
325
326use uuid;
328
// Gated on the same feature as the rest of the module: plain `#[cfg(test)]`
// breaks `cargo test --no-default-features`, because the tests reference
// items that only exist under `feature = "storage"`.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trips one vector through insert + get.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    /// Batch insert generates ids when entries carry none and stores all rows.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    /// Delete reports success and actually shrinks the store.
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    /// Two instances on the same path share one pooled database, so writes
    /// through either are visible to both.
    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        let storage1 = VectorStorage::new(&db_path, 3)?;

        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        let storage2 = VectorStorage::new(&db_path, 3)?;

        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}