1#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{DbOptions, VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
/// Table of vector id -> bincode-encoded `Vec<f32>` payload.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");

/// Table of vector id -> JSON-encoded metadata string.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");

/// Table holding database-level configuration (see `DB_CONFIG_KEY`).
#[cfg(feature = "storage")]
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");

/// Reserved key in `CONFIG_TABLE` under which the serialized `DbOptions` live.
#[cfg(feature = "storage")]
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";

/// Process-wide pool of open redb databases, keyed by absolute path.
///
/// redb permits only one `Database` handle per file, so every `VectorStorage`
/// opened on the same path must share a single `Arc<Database>`.
/// NOTE(review): entries are never evicted, so each database stays open for
/// the lifetime of the process — confirm this is intended.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
40
/// Persistent vector store backed by a redb database shared via `DB_POOL`.
///
/// Gated on the `storage` feature like the imports it depends on; without the
/// attribute the struct fails to compile when the feature is disabled.
#[cfg(feature = "storage")]
pub struct VectorStorage {
    /// Shared database handle; cheap to clone, pooled per absolute path.
    db: Arc<Database>,
    /// Fixed dimensionality every stored vector must match.
    dimensions: usize,
}
46
#[cfg(feature = "storage")]
impl VectorStorage {
    /// Open (or create) the store at `path` for vectors of fixed `dimensions`.
    ///
    /// Creates missing parent directories, rejects relative paths whose `..`
    /// components would escape the current working directory, and reuses an
    /// already-open database from the process-wide `DB_POOL` so multiple
    /// handles on the same path share one redb instance.
    ///
    /// # Errors
    /// Returns `RuvectorError::InvalidPath` on directory-creation/cwd failure
    /// or a detected traversal attempt, and propagates redb errors.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        let path_ref = path.as_ref();

        // Make sure the parent directory exists before redb creates the file.
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Absolutize so "db" and "./db" key the same pool entry.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        // Best-effort traversal guard: walk the components of a relative path
        // and refuse any `..` that would climb above the current directory.
        // NOTE(review): absolute paths containing `..` bypass this check —
        // confirm whether that is acceptable.
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") && !path_ref.is_absolute() {
            if let Ok(cwd) = std::env::current_dir() {
                let mut normalized = cwd.clone();
                for component in path_ref.components() {
                    match component {
                        std::path::Component::ParentDir => {
                            // pop() failing or leaving cwd means we escaped.
                            if !normalized.pop() || !normalized.starts_with(&cwd) {
                                return Err(RuvectorError::InvalidPath(
                                    "Path traversal attempt detected".to_string(),
                                ));
                            }
                        }
                        std::path::Component::Normal(c) => normalized.push(c),
                        _ => {}
                    }
                }
            }
        }

        // Fetch a pooled handle, or open the database and eagerly create all
        // tables so later read transactions never see a missing table.
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                Arc::clone(existing_db)
            } else {
                let new_db = Arc::new(Database::create(&path_buf)?);

                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Insert a single vector (plus optional metadata), returning its id.
    ///
    /// A missing `entry.id` is replaced by a fresh UUIDv4 string.
    ///
    /// # Errors
    /// `DimensionMismatch` if the vector length differs from the store's,
    /// `SerializationError` on encode failure, plus redb errors.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }

        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            // Metadata is stored as JSON in a parallel table under the same id.
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Insert many entries atomically in one write transaction.
    ///
    /// All-or-nothing: any failure aborts the transaction and nothing is
    /// persisted. Returns the ids in input order.
    ///
    /// # Errors
    /// Same error conditions as [`insert`](Self::insert).
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        // Validate every entry before paying for a write transaction, so a
        // bad batch is rejected without touching the database.
        for entry in entries {
            if entry.vector.len() != self.dimensions {
                return Err(RuvectorError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: entry.vector.len(),
                });
            }
        }

        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Fetch the entry stored under `id`, or `Ok(None)` if absent.
    ///
    /// # Errors
    /// `SerializationError` if the stored vector or metadata fails to decode,
    /// plus redb errors.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        // decode_from_slice also yields the number of bytes read; ignored.
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Delete the vector (and any metadata) stored under `id`.
    ///
    /// Returns `true` if a vector was actually removed. Metadata is removed
    /// unconditionally so no orphaned row survives.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;

        let deleted = {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let removed = table.remove(id)?.is_some();

            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;

            removed
        };

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Number of vectors currently stored.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// `true` when the store holds no vectors.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// All vector ids in table iteration order.
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        // Preallocate from the table size to avoid repeated regrowth.
        let mut ids = Vec::with_capacity(table.len()? as usize);
        for item in table.iter()? {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }

    /// Persist the database options as JSON under `DB_CONFIG_KEY`.
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;

        Ok(())
    }

    /// Load previously saved options, or `Ok(None)` when none were stored
    /// (including when the config table itself does not exist yet).
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;

        // A missing table is treated as "no config", not an error.
        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };

        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };

        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        Ok(Some(config))
    }

    /// Fixed dimensionality this store was opened with.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
326
327use uuid;
329
// Gate on the `storage` feature as well as `test`: these tests reference
// `VectorStorage`, which only exists when the feature is enabled, so plain
// `#[cfg(test)]` would break `cargo test --no-default-features`.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trip a single vector through insert/get.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    /// Batch insert auto-generates ids and stores every entry.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    /// Delete reports whether a vector existed and shrinks the store.
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    /// Two `VectorStorage` handles on the same path share one pooled
    /// database, so writes through either are visible through both.
    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        let storage1 = VectorStorage::new(&db_path, 3)?;

        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        let storage2 = VectorStorage::new(&db_path, 3)?;

        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}